summaryrefslogtreecommitdiff
path: root/src/python
diff options
context:
space:
mode:
Diffstat (limited to 'src/python')
-rw-r--r--src/python/CMakeLists.txt62
-rw-r--r--src/python/doc/_templates/layout.html2
-rw-r--r--src/python/doc/alpha_complex_sum.inc15
-rw-r--r--src/python/doc/installation.rst5
-rw-r--r--src/python/doc/persistence_graphical_tools_sum.inc8
-rw-r--r--src/python/doc/persistence_graphical_tools_user.rst32
-rw-r--r--src/python/doc/wasserstein_distance_user.rst17
-rw-r--r--src/python/gudhi/hera.cc71
-rw-r--r--src/python/gudhi/nerve_gic.pyx4
-rw-r--r--src/python/gudhi/persistence_graphical_tools.py124
-rw-r--r--src/python/gudhi/wasserstein.py6
-rw-r--r--src/python/setup.py.in21
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py61
13 files changed, 320 insertions, 108 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt
index 798e2907..22af3ec9 100644
--- a/src/python/CMakeLists.txt
+++ b/src/python/CMakeLists.txt
@@ -32,6 +32,10 @@ function( add_gudhi_debug_info DEBUG_INFO )
endfunction( add_gudhi_debug_info )
if(PYTHONINTERP_FOUND)
+ if(PYBIND11_FOUND)
+ add_gudhi_debug_info("Pybind11 version ${PYBIND11_VERSION}")
+ set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'hera', ")
+ endif()
if(CYTHON_FOUND)
set(GUDHI_PYTHON_MODULES "${GUDHI_PYTHON_MODULES}'off_reader', ")
set(GUDHI_PYTHON_MODULES "${GUDHI_PYTHON_MODULES}'simplex_tree', ")
@@ -87,6 +91,7 @@ if(PYTHONINTERP_FOUND)
endif(MSVC)
if(CMAKE_COMPILER_IS_GNUCXX)
set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-frounding-math', ")
+ set(GUDHI_PYBIND11_EXTRA_COMPILE_ARGS "${GUDHI_PYBIND11_EXTRA_COMPILE_ARGS}'-fvisibility=hidden', ")
endif(CMAKE_COMPILER_IS_GNUCXX)
if (CMAKE_CXX_COMPILER_ID MATCHES Intel)
set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-fp-model strict', ")
@@ -177,7 +182,7 @@ if(PYTHONINTERP_FOUND)
set(GUDHI_PYTHON_LIBRARY_DIRS "${GUDHI_PYTHON_LIBRARY_DIRS}'${MPFR_LIBRARIES_DIR}', ")
message("** Add mpfr ${MPFR_LIBRARIES}")
endif(MPFR_FOUND)
-endif(CGAL_FOUND)
+ endif(CGAL_FOUND)
# Specific for Mac
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
@@ -392,9 +397,9 @@ endif(CGAL_FOUND)
add_gudhi_py_test(test_reader_utils)
# Wasserstein
- if(OT_FOUND)
+ if(OT_FOUND AND PYBIND11_FOUND)
add_gudhi_py_test(test_wasserstein_distance)
- endif(OT_FOUND)
+ endif()
# Representations
if(SKLEARN_FOUND AND MATPLOTLIB_FOUND)
@@ -411,32 +416,37 @@ endif(CGAL_FOUND)
if(SCIPY_FOUND)
if(SKLEARN_FOUND)
if(OT_FOUND)
- if(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
- set (GUDHI_SPHINX_MESSAGE "Generating API documentation with Sphinx in ${CMAKE_CURRENT_BINARY_DIR}/sphinx/")
- # User warning - Sphinx is a static pages generator, and configured to work fine with user_version
- # Images and biblio warnings because not found on developper version
- if (GUDHI_PYTHON_PATH STREQUAL "src/python")
- set (GUDHI_SPHINX_MESSAGE "${GUDHI_SPHINX_MESSAGE} \n WARNING : Sphinx is configured for user version, you run it on developper version. Images and biblio will miss")
- endif()
- # sphinx target requires gudhi.so, because conf.py reads gudhi version from it
- add_custom_target(sphinx
- WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/doc
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
- ${SPHINX_PATH} -b html ${CMAKE_CURRENT_SOURCE_DIR}/doc ${CMAKE_CURRENT_BINARY_DIR}/sphinx
- DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/gudhi.so"
- COMMENT "${GUDHI_SPHINX_MESSAGE}" VERBATIM)
+ if(PYBIND11_FOUND)
+ if(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
+ set (GUDHI_SPHINX_MESSAGE "Generating API documentation with Sphinx in ${CMAKE_CURRENT_BINARY_DIR}/sphinx/")
+ # User warning - Sphinx is a static pages generator, and configured to work fine with user_version
+ # Images and biblio warnings because not found on developper version
+ if (GUDHI_PYTHON_PATH STREQUAL "src/python")
+ set (GUDHI_SPHINX_MESSAGE "${GUDHI_SPHINX_MESSAGE} \n WARNING : Sphinx is configured for user version, you run it on developper version. Images and biblio will miss")
+ endif()
+ # sphinx target requires gudhi.so, because conf.py reads gudhi version from it
+ add_custom_target(sphinx
+ WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/doc
+ COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
+ ${SPHINX_PATH} -b html ${CMAKE_CURRENT_SOURCE_DIR}/doc ${CMAKE_CURRENT_BINARY_DIR}/sphinx
+ DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/gudhi.so"
+ COMMENT "${GUDHI_SPHINX_MESSAGE}" VERBATIM)
- add_test(NAME sphinx_py_test
- WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
- ${SPHINX_PATH} -b doctest ${CMAKE_CURRENT_SOURCE_DIR}/doc ${CMAKE_CURRENT_BINARY_DIR}/doctest)
+ add_test(NAME sphinx_py_test
+ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
+ COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}"
+ ${SPHINX_PATH} -b doctest ${CMAKE_CURRENT_SOURCE_DIR}/doc ${CMAKE_CURRENT_BINARY_DIR}/doctest)
- # Set missing or not modules
- set(GUDHI_MODULES ${GUDHI_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MODULES")
- else(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
- message("++ Python documentation module will not be compiled because it requires a Eigen3 and CGAL version >= 4.11.0")
+ # Set missing or not modules
+ set(GUDHI_MODULES ${GUDHI_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MODULES")
+ else(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
+ message("++ Python documentation module will not be compiled because it requires a Eigen3 and CGAL version >= 4.11.0")
+ set(GUDHI_MISSING_MODULES ${GUDHI_MISSING_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MISSING_MODULES")
+ endif(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
+ else(PYBIND11_FOUND)
+ message("++ Python documentation module will not be compiled because pybind11 was not found")
set(GUDHI_MISSING_MODULES ${GUDHI_MISSING_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MISSING_MODULES")
- endif(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0)
+ endif(PYBIND11_FOUND)
else(OT_FOUND)
message("++ Python documentation module will not be compiled because POT was not found")
set(GUDHI_MISSING_MODULES ${GUDHI_MISSING_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MISSING_MODULES")
diff --git a/src/python/doc/_templates/layout.html b/src/python/doc/_templates/layout.html
index 2f2d9c72..a672a281 100644
--- a/src/python/doc/_templates/layout.html
+++ b/src/python/doc/_templates/layout.html
@@ -201,7 +201,7 @@
<a href="#">Download</a>
<ul class="dropdown">
<li><a href="/licensing/">Licensing</a></li>
- <li><a href="https://gforge.inria.fr/frs/download.php/latestzip/5253/library-latest.zip" target="_blank">Get the latest sources</a></li>
+ <li><a href="https://github.com/GUDHI/gudhi-devel/releases/latest" target="_blank">Get the latest sources</a></li>
<li><a href="/conda/">Conda package</a></li>
<li><a href="/dockerfile/">Dockerfile</a></li>
</ul>
diff --git a/src/python/doc/alpha_complex_sum.inc b/src/python/doc/alpha_complex_sum.inc
index a1184663..b5af0d27 100644
--- a/src/python/doc/alpha_complex_sum.inc
+++ b/src/python/doc/alpha_complex_sum.inc
@@ -5,16 +5,13 @@
| .. figure:: | Alpha complex is a simplicial complex constructed from the finite | :Author: Vincent Rouvreau |
| ../../doc/Alpha_complex/alpha_complex_representation.png | cells of a Delaunay Triangulation. | |
| :alt: Alpha complex representation | | :Introduced in: GUDHI 2.0.0 |
- | :figclass: align-center | The filtration value of each simplex is computed as the square of the | |
- | | circumradius of the simplex if the circumsphere is empty (the simplex | :Copyright: MIT (`GPL v3 </licensing/>`_) |
- | | is then said to be Gabriel), and as the minimum of the filtration | |
- | | values of the codimension 1 cofaces that make it not Gabriel | :Requires: `Eigen <installation.html#eigen>`__ :math:`\geq` 3.1.0 and `CGAL <installation.html#cgal>`__ :math:`\geq` 4.11.0 |
- | | otherwise. All simplices that have a filtration value | |
- | | :math:`> \alpha^2` are removed from the Delaunay complex | |
- | | when creating the simplicial complex if it is specified. | |
+ | :figclass: align-center | The filtration value of each simplex is computed as the **square** of | |
+ | | the circumradius of the simplex if the circumsphere is empty (the | :Copyright: MIT (`GPL v3 </licensing/>`_) |
+ | | simplex is then said to be Gabriel), and as the minimum of the | |
+ | | filtration values of the codimension 1 cofaces that make it not | :Requires: `Eigen <installation.html#eigen>`__ :math:`\geq` 3.1.0 and `CGAL <installation.html#cgal>`__ :math:`\geq` 4.11.0 |
+ | | Gabriel otherwise. | |
| | | |
- | | This package requires having CGAL version 4.7 or higher (4.8.1 is | |
- | | advised for better performance). | |
+ | | For performances reasons, it is advised to use CGAL ≥ 5.0.0. | |
+----------------------------------------------------------------+------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------+
| * :doc:`alpha_complex_user` | * :doc:`alpha_complex_ref` |
+----------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
diff --git a/src/python/doc/installation.rst b/src/python/doc/installation.rst
index 40f3f44b..d459145b 100644
--- a/src/python/doc/installation.rst
+++ b/src/python/doc/installation.rst
@@ -14,10 +14,11 @@ Compiling
*********
The library uses c++14 and requires `Boost <https://www.boost.org/>`_ ≥ 1.56.0,
`CMake <https://www.cmake.org/>`_ ≥ 3.1 to generate makefiles,
-`NumPy <http://numpy.org>`_ and `Cython <https://www.cython.org/>`_ to compile
+`NumPy <http://numpy.org>`_, `Cython <https://www.cython.org/>`_ and
+`pybind11 <https://github.com/pybind/pybind11>`_ to compile
the GUDHI Python module.
It is a multi-platform library and compiles on Linux, Mac OSX and Visual
-Studio 2015.
+Studio 2017.
On `Windows <https://wiki.python.org/moin/WindowsCompilers>`_ , only Python
≥ 3.5 are available because of the required Visual Studio version.
diff --git a/src/python/doc/persistence_graphical_tools_sum.inc b/src/python/doc/persistence_graphical_tools_sum.inc
index 0cdf8072..ef376802 100644
--- a/src/python/doc/persistence_graphical_tools_sum.inc
+++ b/src/python/doc/persistence_graphical_tools_sum.inc
@@ -2,11 +2,11 @@
:widths: 30 50 20
+-----------------------------------------------------------------+-----------------------------------------------------------------------+-----------------------------------------------+
- | .. figure:: | These graphical tools comes on top of persistence results and allows | :Author: Vincent Rouvreau |
- | img/graphical_tools_representation.png | the user to build easily persistence barcode, diagram or density. | |
+ | .. figure:: | These graphical tools comes on top of persistence results and allows | :Author: Vincent Rouvreau, Theo Lacombe |
+ | img/graphical_tools_representation.png | the user to display easily persistence barcode, diagram or density. | |
| | | :Introduced in: GUDHI 2.0.0 |
- | | | |
- | | | :Copyright: MIT |
+ | | Note that these functions return the matplotlib axis, allowing | |
+ | | for further modifications (title, aspect, etc.) | :Copyright: MIT |
| | | |
| | | :Requires: matplotlib, numpy and scipy |
+-----------------------------------------------------------------+-----------------------------------------------------------------------+-----------------------------------------------+
diff --git a/src/python/doc/persistence_graphical_tools_user.rst b/src/python/doc/persistence_graphical_tools_user.rst
index 80002db6..91e52703 100644
--- a/src/python/doc/persistence_graphical_tools_user.rst
+++ b/src/python/doc/persistence_graphical_tools_user.rst
@@ -20,7 +20,7 @@ This function can display the persistence result as a barcode:
.. plot::
:include-source:
- import matplotlib.pyplot as plot
+ import matplotlib.pyplot as plt
import gudhi
off_file = gudhi.__root_source_dir__ + '/data/points/tore3D_300.off'
@@ -31,7 +31,7 @@ This function can display the persistence result as a barcode:
diag = simplex_tree.persistence(min_persistence=0.4)
gudhi.plot_persistence_barcode(diag)
- plot.show()
+ plt.show()
Show persistence as a diagram
-----------------------------
@@ -44,15 +44,31 @@ This function can display the persistence result as a diagram:
.. plot::
:include-source:
- import matplotlib.pyplot as plot
+ import matplotlib.pyplot as plt
import gudhi
# rips_on_tore3D_1307.pers obtained from write_persistence_diagram method
persistence_file=gudhi.__root_source_dir__ + \
'/data/persistence_diagram/rips_on_tore3D_1307.pers'
- gudhi.plot_persistence_diagram(persistence_file=persistence_file,
+ ax = gudhi.plot_persistence_diagram(persistence_file=persistence_file,
legend=True)
- plot.show()
+ # We can modify the title, aspect, etc.
+ ax.set_title("Persistence diagram of a torus")
+ ax.set_aspect("equal") # forces to be square shaped
+ plt.show()
+
+Note that (as barcode and density) it can also take a simple `np.array`
+of shape (N x 2) encoding a persistence diagram (in a given dimension).
+
+.. plot::
+ :include-source:
+
+ import matplotlib.pyplot as plt
+ import gudhi
+ import numpy as np
+ d = np.array([[0, 1], [1, 2], [1, np.inf]])
+ gudhi.plot_persistence_diagram(d)
+ plt.show()
Persistence density
-------------------
@@ -65,7 +81,7 @@ If you want more information on a specific dimension, for instance:
.. plot::
:include-source:
- import matplotlib.pyplot as plot
+ import matplotlib.pyplot as plt
import gudhi
# rips_on_tore3D_1307.pers obtained from write_persistence_diagram method
persistence_file=gudhi.__root_source_dir__ + \
@@ -75,9 +91,9 @@ If you want more information on a specific dimension, for instance:
only_this_dim=1)
pers_diag = [(1, elt) for elt in birth_death]
# Use subplots to display diagram and density side by side
- fig, axes = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
+ fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))
gudhi.plot_persistence_diagram(persistence=pers_diag,
axes=axes[0])
gudhi.plot_persistence_density(persistence=pers_diag,
dimension=1, legend=True, axes=axes[1])
- plot.show()
+ plt.show()
diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst
index 32999a0c..94b454e2 100644
--- a/src/python/doc/wasserstein_distance_user.rst
+++ b/src/python/doc/wasserstein_distance_user.rst
@@ -9,17 +9,26 @@ Definition
.. include:: wasserstein_distance_sum.inc
-This implementation is based on ideas from "Large Scale Computation of Means and Cluster for Persistence Diagrams via Optimal Transport".
+Functions
+---------
+This implementation uses the Python Optimal Transport library and is based on
+ideas from "Large Scale Computation of Means and Cluster for Persistence
+Diagrams via Optimal Transport" :cite:`10.5555/3327546.3327645`.
-Function
---------
.. autofunction:: gudhi.wasserstein.wasserstein_distance
+This other implementation comes from `Hera
+<https://bitbucket.org/grey_narn/hera/src/master/>`_ (BSD-3-Clause) which is
+based on "Geometry Helps to Compare Persistence Diagrams"
+:cite:`Kerber:2017:GHC:3047249.3064175` by Michael Kerber, Dmitriy
+Morozov, and Arnur Nigmetov.
+
+.. autofunction:: gudhi.hera.wasserstein_distance
Basic example
-------------
-This example computes the 1-Wasserstein distance from 2 persistence diagrams with euclidean ground metric.
+This example computes the 1-Wasserstein distance from 2 persistence diagrams with Euclidean ground metric.
Note that persistence diagrams must be submitted as (n x 2) numpy arrays and must not contain inf values.
.. testcode::
diff --git a/src/python/gudhi/hera.cc b/src/python/gudhi/hera.cc
new file mode 100644
index 00000000..0d562b4c
--- /dev/null
+++ b/src/python/gudhi/hera.cc
@@ -0,0 +1,71 @@
+/* This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ * See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ * Author(s): Marc Glisse
+ *
+ * Copyright (C) 2020 Inria
+ *
+ * Modification(s):
+ * - YYYY/MM Author: Description of the modification
+ */
+
+#include <pybind11/pybind11.h>
+#include <pybind11/numpy.h>
+
+#include <boost/range/iterator_range.hpp>
+
+#include <wasserstein.h> // Hera
+
+#include <array>
+
+namespace py = pybind11;
+typedef py::array_t<double, py::array::c_style | py::array::forcecast> Dgm;
+
+double wasserstein_distance(
+ Dgm d1, Dgm d2,
+ double wasserstein_power, double internal_p,
+ double delta)
+{
+ py::buffer_info buf1 = d1.request();
+ py::buffer_info buf2 = d2.request();
+ // shape (n,2) or (0) for empty
+ if((buf1.ndim!=2 || buf1.shape[1]!=2) && (buf1.ndim!=1 || buf1.shape[0]!=0))
+ throw std::runtime_error("Diagram 1 must be an array of size n x 2");
+ if((buf2.ndim!=2 || buf2.shape[1]!=2) && (buf2.ndim!=1 || buf2.shape[0]!=0))
+ throw std::runtime_error("Diagram 2 must be an array of size n x 2");
+ typedef std::array<double, 2> Point;
+ auto p1 = (Point*)buf1.ptr;
+ auto p2 = (Point*)buf2.ptr;
+ auto diag1 = boost::make_iterator_range(p1, p1+buf1.shape[0]);
+ auto diag2 = boost::make_iterator_range(p2, p2+buf2.shape[0]);
+
+ hera::AuctionParams<double> params;
+ params.wasserstein_power = wasserstein_power;
+ // hera encodes infinity as -1...
+ if(std::isinf(internal_p)) internal_p = hera::get_infinity<double>();
+ params.internal_p = internal_p;
+ params.delta = delta;
+ // The extra parameters are purposedly not exposed for now.
+ return hera::wasserstein_dist(diag1, diag2, params);
+}
+
+PYBIND11_MODULE(hera, m) {
+ m.def("wasserstein_distance", &wasserstein_distance,
+ py::arg("X"), py::arg("Y"),
+ py::arg("order") = 1,
+ py::arg("internal_p") = std::numeric_limits<double>::infinity(),
+ py::arg("delta") = .01,
+ R"pbdoc(
+ Compute the Wasserstein distance between two diagrams.
+ Points at infinity are supported.
+
+ Parameters:
+ X (n x 2 numpy array): First diagram
+ Y (n x 2 numpy array): Second diagram
+ order (float): Wasserstein exponent W_q
+ internal_p (float): Internal Minkowski norm L^p in R^2
+ delta (float): Relative error 1+delta
+
+ Returns:
+ float: Approximate Wasserstein distance W_q(X,Y)
+ )pbdoc");
+}
diff --git a/src/python/gudhi/nerve_gic.pyx b/src/python/gudhi/nerve_gic.pyx
index 382e71c5..45cc8eba 100644
--- a/src/python/gudhi/nerve_gic.pyx
+++ b/src/python/gudhi/nerve_gic.pyx
@@ -187,7 +187,7 @@ cdef class CoverComplex:
def set_automatic_resolution(self):
"""Computes the optimal length of intervals (i.e. the smallest interval
- length avoiding discretization artifacts—see :cite:`Carriere17c`) for a
+ length avoiding discretization artifacts - see :cite:`Carriere17c`) for a
functional cover.
:rtype: double
@@ -288,7 +288,7 @@ cdef class CoverComplex:
def set_graph_from_automatic_rips(self, N=100):
"""Creates a graph G from a Rips complex whose threshold value is
- automatically tuned with subsampling—see.
+ automatically tuned with subsampling - see :cite:`Carriere17c`.
:param N: Number of subsampling iteration (the default reasonable value
is 100, but there is no guarantee on how to choose it).
diff --git a/src/python/gudhi/persistence_graphical_tools.py b/src/python/gudhi/persistence_graphical_tools.py
index 246280de..cc3db467 100644
--- a/src/python/gudhi/persistence_graphical_tools.py
+++ b/src/python/gudhi/persistence_graphical_tools.py
@@ -5,6 +5,7 @@
# Copyright (C) 2016 Inria
#
# Modification(s):
+# - 2020/02 Theo Lacombe: Added more options for improved rendering and more flexibility.
# - YYYY/MM Author: Description of the modification
from os import path
@@ -14,7 +15,7 @@ import numpy as np
from gudhi.reader_utils import read_persistence_intervals_in_dimension
from gudhi.reader_utils import read_persistence_intervals_grouped_by_dimension
-__author__ = "Vincent Rouvreau, Bertrand Michel"
+__author__ = "Vincent Rouvreau, Bertrand Michel, Theo Lacombe"
__copyright__ = "Copyright (C) 2016 Inria"
__license__ = "MIT"
@@ -43,6 +44,19 @@ def __min_birth_max_death(persistence, band=0.0):
max_death += band
return (min_birth, max_death)
+
+def _array_handler(a):
+ '''
+ :param a: if array, assumes it is a (n x 2) np.array and return a
+ persistence-compatible list (padding with 0), so that the
+ plot can be performed seamlessly.
+ '''
+ if isinstance(a[0][1], np.float64) or isinstance(a[0][1], float):
+ return [[0, x] for x in a]
+ else:
+ return a
+
+
def plot_persistence_barcode(
persistence=[],
persistence_file="",
@@ -52,13 +66,16 @@ def plot_persistence_barcode(
inf_delta=0.1,
legend=False,
colormap=None,
- axes=None
+ axes=None,
+ fontsize=16,
):
"""This function plots the persistence bar code from persistence values list
+ , a np.array of shape (N x 2) (representing a diagram
+ in a single homology dimension),
or from a :doc:`persistence file <fileformats>`.
- :param persistence: Persistence intervals values list grouped by dimension.
- :type persistence: list of tuples(dimension, tuple(birth, death)).
+ :param persistence: Persistence intervals values list. Can be grouped by dimension or not.
+ :type persistence: an array of (dimension, array of (birth, death)) or an array of (birth, death).
:param persistence_file: A :doc:`persistence file <fileformats>` style name
(reset persistence if both are set).
:type persistence_file: string
@@ -81,11 +98,19 @@ def plot_persistence_barcode(
:param axes: A matplotlib-like subplot axes. If None, the plot is drawn on
a new set of axes.
:type axes: `matplotlib.axes.Axes`
+ :param fontsize: Fontsize to use in axis.
+ :type fontsize: int
:returns: (`matplotlib.axes.Axes`): The axes on which the plot was drawn.
"""
try:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
+ from matplotlib import rc
+ plt.rc('text', usetex=True)
+ plt.rc('font', family='serif')
+
+
+ persistence = _array_handler(persistence)
if persistence_file != "":
if path.isfile(persistence_file):
@@ -163,7 +188,7 @@ def plot_persistence_barcode(
loc="lower right",
)
- axes.set_title("Persistence barcode")
+ axes.set_title("Persistence barcode", fontsize=fontsize)
# Ends plot on infinity value and starts a little bit before min_birth
axes.axis([axis_start, infinity, 0, ind])
@@ -183,13 +208,16 @@ def plot_persistence_diagram(
inf_delta=0.1,
legend=False,
colormap=None,
- axes=None
+ axes=None,
+ fontsize=16,
+ greyblock=True
):
"""This function plots the persistence diagram from persistence values
- list or from a :doc:`persistence file <fileformats>`.
+ list, a np.array of shape (N x 2) representing a diagram in a single
+ homology dimension, or from a :doc:`persistence file <fileformats>`.
- :param persistence: Persistence intervals values list grouped by dimension.
- :type persistence: list of tuples(dimension, tuple(birth, death)).
+ :param persistence: Persistence intervals values list. Can be grouped by dimension or not.
+ :type persistence: an array of (dimension, array of (birth, death)) or an array of (birth, death).
:param persistence_file: A :doc:`persistence file <fileformats>` style name
(reset persistence if both are set).
:type persistence_file: string
@@ -214,11 +242,20 @@ def plot_persistence_diagram(
:param axes: A matplotlib-like subplot axes. If None, the plot is drawn on
a new set of axes.
:type axes: `matplotlib.axes.Axes`
+ :param fontsize: Fontsize to use in axis.
+ :type fontsize: int
+ :param greyblock: if we want to plot a grey patch on the lower half plane for nicer rendering. Default True.
+ :type greyblock: boolean
:returns: (`matplotlib.axes.Axes`): The axes on which the plot was drawn.
"""
try:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
+ from matplotlib import rc
+ plt.rc('text', usetex=True)
+ plt.rc('font', family='serif')
+
+ persistence = _array_handler(persistence)
if persistence_file != "":
if path.isfile(persistence_file):
@@ -256,19 +293,18 @@ def plot_persistence_diagram(
# Replace infinity values with max_death + delta for diagram to be more
# readable
infinity = max_death + delta
+ axis_end = max_death + delta / 2
axis_start = min_birth - delta
- # line display of equation : birth = death
- x = np.linspace(axis_start, infinity, 1000)
- # infinity line and text
- axes.plot(x, x, color="k", linewidth=1.0)
- axes.plot(x, [infinity] * len(x), linewidth=1.0, color="k", alpha=alpha)
- axes.text(axis_start, infinity, r"$\infty$", color="k", alpha=alpha)
# bootstrap band
if band > 0.0:
+ x = np.linspace(axis_start, infinity, 1000)
axes.fill_between(x, x, x + band, alpha=alpha, facecolor="red")
-
+ # lower diag patch
+ if greyblock:
+ axes.add_patch(mpatches.Polygon([[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]], fill=True, color='lightgrey'))
# Draw points in loop
+ pts_at_infty = False # Records presence of pts at infty
for interval in reversed(persistence):
if float(interval[1][1]) != float("inf"):
# Finite death case
@@ -279,10 +315,23 @@ def plot_persistence_diagram(
color=colormap[interval[0]],
)
else:
+ pts_at_infty = True
# Infinite death case for diagram to be nicer
axes.scatter(
interval[1][0], infinity, alpha=alpha, color=colormap[interval[0]]
)
+ if pts_at_infty:
+ # infinity line and text
+ axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k")
+ axes.plot([axis_start, axis_end], [infinity, infinity], linewidth=1.0, color="k", alpha=alpha)
+ # Infinity label
+ yt = axes.get_yticks()
+ yt = yt[np.where(yt < axis_end)] # to avoid ploting ticklabel higher than infinity
+ yt = np.append(yt, infinity)
+ ytl = ["%.3f" % e for e in yt] # to avoid float precision error
+ ytl[-1] = r'$+\infty$'
+ axes.set_yticks(yt)
+ axes.set_yticklabels(ytl)
if legend:
dimensions = list(set(item[0] for item in persistence))
@@ -293,11 +342,11 @@ def plot_persistence_diagram(
]
)
- axes.set_xlabel("Birth")
- axes.set_ylabel("Death")
+ axes.set_xlabel("Birth", fontsize=fontsize)
+ axes.set_ylabel("Death", fontsize=fontsize)
+ axes.set_title("Persistence diagram", fontsize=fontsize)
# Ends plot on infinity value and starts a little bit before min_birth
- axes.axis([axis_start, infinity, axis_start, infinity + delta])
- axes.set_title("Persistence diagram")
+ axes.axis([axis_start, axis_end, axis_start, infinity + delta/2])
return axes
except ImportError:
@@ -313,16 +362,22 @@ def plot_persistence_density(
dimension=None,
cmap=None,
legend=False,
- axes=None
+ axes=None,
+ fontsize=16,
+ greyblock=False
):
"""This function plots the persistence density from persistence
- values list or from a :doc:`persistence file <fileformats>`. Be
+ values list, np.array of shape (N x 2) representing a diagram
+ in a single homology dimension,
+ or from a :doc:`persistence file <fileformats>`. Be
aware that this function does not distinguish the dimension, it is
up to you to select the required one. This function also does not handle
degenerate data set (scipy correlation matrix inversion can fail).
- :param persistence: Persistence intervals values list grouped by dimension.
- :type persistence: list of tuples(dimension, tuple(birth, death)).
+ :param persistence: Persistence intervals values list.
+ Can be grouped by dimension or not.
+ :type persistence: an array of (dimension, array of (birth, death))
+ or an array of (birth, death).
:param persistence_file: A :doc:`persistence file <fileformats>`
style name (reset persistence if both are set).
:type persistence_file: string
@@ -355,11 +410,22 @@ def plot_persistence_density(
:param axes: A matplotlib-like subplot axes. If None, the plot is drawn on
a new set of axes.
:type axes: `matplotlib.axes.Axes`
+ :param fontsize: Fontsize to use in axis.
+ :type fontsize: int
+ :param greyblock: if we want to plot a grey patch on the lower half plane
+ for nicer rendering. Default False.
+ :type greyblock: boolean
:returns: (`matplotlib.axes.Axes`): The axes on which the plot was drawn.
"""
try:
import matplotlib.pyplot as plt
+ import matplotlib.patches as mpatches
from scipy.stats import kde
+ from matplotlib import rc
+ plt.rc('text', usetex=True)
+ plt.rc('font', family='serif')
+
+ persistence = _array_handler(persistence)
if persistence_file != "":
if dimension is None:
@@ -418,12 +484,16 @@ def plot_persistence_density(
# Make the plot
img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap)
+ if greyblock:
+ axes.add_patch(mpatches.Polygon([[birth.min(), birth.min()], [death.max(), birth.min()], [death.max(), death.max()]], fill=True, color='lightgrey'))
+
if legend:
plt.colorbar(img, ax=axes)
- axes.set_xlabel("Birth")
- axes.set_ylabel("Death")
- axes.set_title("Persistence density")
+ axes.set_xlabel("Birth", fontsize=fontsize)
+ axes.set_ylabel("Death", fontsize=fontsize)
+ axes.set_title("Persistence density", fontsize=fontsize)
+
return axes
except ImportError:
diff --git a/src/python/gudhi/wasserstein.py b/src/python/gudhi/wasserstein.py
index db5ddff2..13102094 100644
--- a/src/python/gudhi/wasserstein.py
+++ b/src/python/gudhi/wasserstein.py
@@ -27,8 +27,8 @@ def _build_dist_matrix(X, Y, order=2., internal_p=2.):
'''
:param X: (n x 2) numpy.array encoding the (points of the) first diagram.
:param Y: (m x 2) numpy.array encoding the second diagram.
- :param internal_p: Ground metric (i.e. norm l_p).
:param order: exponent for the Wasserstein metric.
+ :param internal_p: Ground metric (i.e. norm L^p).
:returns: (n+1) x (m+1) np.array encoding the cost matrix C.
For 1 <= i <= n, 1 <= j <= m, C[i,j] encodes the distance between X[i] and Y[j], while C[i, m+1] (resp. C[n+1, j]) encodes the distance (to the p) between X[i] (resp Y[j]) and its orthogonal proj onto the diagonal.
note also that C[n+1, m+1] = 0 (it costs nothing to move from the diagonal to the diagonal).
@@ -54,8 +54,8 @@ def _build_dist_matrix(X, Y, order=2., internal_p=2.):
def _perstot(X, order, internal_p):
'''
:param X: (n x 2) numpy.array (points of a given diagram).
- :param internal_p: Ground metric on the (upper-half) plane (i.e. norm l_p in R^2); Default value is 2 (Euclidean norm).
:param order: exponent for Wasserstein. Default value is 2.
+ :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm).
:returns: float, the total persistence of the diagram (that is, its distance to the empty diagram).
'''
Xdiag = _proj_on_diag(X)
@@ -66,8 +66,8 @@ def wasserstein_distance(X, Y, order=2., internal_p=2.):
'''
:param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points (i.e. with infinite coordinate).
:param Y: (m x 2) numpy.array encoding the second diagram.
- :param internal_p: Ground metric on the (upper-half) plane (i.e. norm l_p in R^2); Default value is 2 (euclidean norm).
:param order: exponent for Wasserstein; Default value is 2.
+ :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm).
:returns: the Wasserstein distance of order q (1 <= q < infinity) between persistence diagrams with respect to the internal_p-norm as ground metric.
:rtype: float
'''
diff --git a/src/python/setup.py.in b/src/python/setup.py.in
index 9c2124f4..f968bd59 100644
--- a/src/python/setup.py.in
+++ b/src/python/setup.py.in
@@ -8,10 +8,11 @@
- YYYY/MM Author: Description of the modification
"""
-from setuptools import setup, Extension
+from setuptools import setup, Extension, find_packages
from Cython.Build import cythonize
from numpy import get_include as numpy_get_include
import sys
+import pybind11
__author__ = "Vincent Rouvreau"
__copyright__ = "Copyright (C) 2016 Inria"
@@ -42,14 +43,26 @@ for module in modules:
runtime_library_dirs=runtime_library_dirs,
cython_directives = {'language_level': str(sys.version_info[0])},))
+ext_modules = cythonize(ext_modules)
+
+ext_modules.append(Extension(
+ 'gudhi.hera',
+ sources = [source_dir + 'hera.cc'],
+ language = 'c++',
+ include_dirs = include_dirs +
+ ['@HERA_WASSERSTEIN_INCLUDE_DIR@',
+ pybind11.get_include(False), pybind11.get_include(True)],
+ extra_compile_args=extra_compile_args + [@GUDHI_PYBIND11_EXTRA_COMPILE_ARGS@],
+ ))
+
setup(
name = 'gudhi',
- packages=["gudhi",],
+ packages=find_packages(), # find_namespace_packages(include=["gudhi*"])
author='GUDHI Editorial Board',
author_email='gudhi-contact@lists.gforge.inria.fr',
version='@GUDHI_VERSION@',
url='http://gudhi.gforge.inria.fr/',
- ext_modules = cythonize(ext_modules),
+ ext_modules = ext_modules,
install_requires = ['cython','numpy >= 1.9',],
- setup_requires = ['numpy >= 1.9',],
+ setup_requires = ['numpy >= 1.9','pybind11',],
)
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 43dda77e..6a6b217b 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -1,6 +1,6 @@
""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
- Author(s): Theo Lacombe
+ Author(s): Theo Lacombe, Marc Glisse
Copyright (C) 2019 Inria
@@ -8,41 +8,66 @@
- YYYY/MM Author: Description of the modification
"""
-from gudhi.wasserstein import wasserstein_distance
+from gudhi.wasserstein import wasserstein_distance as pot
+from gudhi.hera import wasserstein_distance as hera
import numpy as np
+import pytest
__author__ = "Theo Lacombe"
__copyright__ = "Copyright (C) 2019 Inria"
__license__ = "MIT"
-
-def test_basic_wasserstein():
+def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True):
diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]])
diag2 = np.array([[2.8, 4.45], [9.5, 14.1]])
diag3 = np.array([[0, 2], [4, 6]])
diag4 = np.array([[0, 3], [4, 8]])
- emptydiag = np.array([[]])
+ emptydiag = np.array([])
+
+ # We just need to handle positive numbers here
+ def approx(x):
+ return pytest.approx(x, rel=delta)
assert wasserstein_distance(emptydiag, emptydiag, internal_p=2., order=1.) == 0.
assert wasserstein_distance(emptydiag, emptydiag, internal_p=np.inf, order=1.) == 0.
assert wasserstein_distance(emptydiag, emptydiag, internal_p=np.inf, order=2.) == 0.
assert wasserstein_distance(emptydiag, emptydiag, internal_p=2., order=2.) == 0.
- assert wasserstein_distance(diag3, emptydiag, internal_p=np.inf, order=1.) == 2.
- assert wasserstein_distance(diag3, emptydiag, internal_p=1., order=1.) == 4.
+ assert wasserstein_distance(diag3, emptydiag, internal_p=np.inf, order=1.) == approx(2.)
+ assert wasserstein_distance(diag3, emptydiag, internal_p=1., order=1.) == approx(4.)
+
+ assert wasserstein_distance(diag4, emptydiag, internal_p=1., order=2.) == approx(5.) # thank you Pythagorician triplets
+ assert wasserstein_distance(diag4, emptydiag, internal_p=np.inf, order=2.) == approx(2.5)
+ assert wasserstein_distance(diag4, emptydiag, internal_p=2., order=2.) == approx(3.5355339059327378)
+
+ assert wasserstein_distance(diag1, diag2, internal_p=2., order=1.) == approx(1.4453593023967701)
+ assert wasserstein_distance(diag1, diag2, internal_p=2.35, order=1.74) == approx(0.9772734057168739)
+
+ assert wasserstein_distance(diag1, emptydiag, internal_p=2.35, order=1.7863) == approx(3.141592214572228)
+
+ assert wasserstein_distance(diag3, diag4, internal_p=1., order=1.) == approx(3.)
+ assert wasserstein_distance(diag3, diag4, internal_p=np.inf, order=1.) == approx(3.) # no diag matching here
+ assert wasserstein_distance(diag3, diag4, internal_p=np.inf, order=2.) == approx(np.sqrt(5))
+ assert wasserstein_distance(diag3, diag4, internal_p=1., order=2.) == approx(np.sqrt(5))
+ assert wasserstein_distance(diag3, diag4, internal_p=4.5, order=2.) == approx(np.sqrt(5))
+
+ if(not test_infinity):
+ return
- assert wasserstein_distance(diag4, emptydiag, internal_p=1., order=2.) == 5. # thank you Pythagorician triplets
- assert wasserstein_distance(diag4, emptydiag, internal_p=np.inf, order=2.) == 2.5
- assert wasserstein_distance(diag4, emptydiag, internal_p=2., order=2.) == 3.5355339059327378
+ diag5 = np.array([[0, 3], [4, np.inf]])
+ diag6 = np.array([[7, 8], [4, 6], [3, np.inf]])
- assert wasserstein_distance(diag1, diag2, internal_p=2., order=1.) == 1.4453593023967701
- assert wasserstein_distance(diag1, diag2, internal_p=2.35, order=1.74) == 0.9772734057168739
+ assert wasserstein_distance(diag4, diag5) == np.inf
+ assert wasserstein_distance(diag5, diag6, order=1, internal_p=np.inf) == approx(4.)
- assert wasserstein_distance(diag1, emptydiag, internal_p=2.35, order=1.7863) == 3.141592214572228
+def hera_wrap(delta):
+ def fun(*kargs,**kwargs):
+ return hera(*kargs,**kwargs,delta=delta)
+ return fun
- assert wasserstein_distance(diag3, diag4, internal_p=1., order=1.) == 3.
- assert wasserstein_distance(diag3, diag4, internal_p=np.inf, order=1.) == 3. # no diag matching here
- assert wasserstein_distance(diag3, diag4, internal_p=np.inf, order=2.) == np.sqrt(5)
- assert wasserstein_distance(diag3, diag4, internal_p=1., order=2.) == np.sqrt(5)
- assert wasserstein_distance(diag3, diag4, internal_p=4.5, order=2.) == np.sqrt(5)
+def test_wasserstein_distance_pot():
+ _basic_wasserstein(pot, 1e-15, test_infinity=False)
+def test_wasserstein_distance_hera():
+ _basic_wasserstein(hera_wrap(1e-12), 1e-12)
+ _basic_wasserstein(hera_wrap(.1), .1)