summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-04-14 19:49:49 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-04-14 19:49:49 +0200
commit3f1e4bf5f139afe887ae501f20c5d3f40b5a6f79 (patch)
tree936aa887b1e7a0e23854262f17a5d21cf2f604c1
parent9518287cfa2a62948ede2e7d17d5c9f29092e0f4 (diff)
parent6d02ca0e077cc9750275abdfc024429cec0ba5a5 (diff)
Merge remote-tracking branch 'origin/master' into dtm
-rw-r--r--.github/next_release.md1
-rw-r--r--.github/test-requirements.txt6
-rw-r--r--Dockerfile_for_circleci_image3
-rw-r--r--biblio/bibliography.bib46
-rw-r--r--src/Simplex_tree/include/gudhi/Simplex_tree.h124
-rw-r--r--src/python/CMakeLists.txt3
-rw-r--r--src/python/doc/alpha_complex_user.rst4
-rw-r--r--src/python/doc/bottleneck_distance_user.rst7
-rw-r--r--src/python/doc/cubical_complex_user.rst4
-rw-r--r--src/python/doc/img/barycenter.pngbin0 -> 12433 bytes
-rw-r--r--src/python/doc/index.rst3
-rw-r--r--src/python/doc/installation.rst8
-rw-r--r--src/python/doc/nerve_gic_complex_ref.rst7
-rw-r--r--src/python/doc/nerve_gic_complex_user.rst7
-rw-r--r--src/python/doc/persistent_cohomology_user.rst4
-rw-r--r--src/python/doc/rips_complex_user.rst7
-rw-r--r--src/python/doc/simplex_tree_user.rst7
-rw-r--r--src/python/doc/tangential_complex_user.rst4
-rw-r--r--src/python/doc/wasserstein_distance_sum.inc10
-rw-r--r--src/python/doc/wasserstein_distance_user.rst103
-rw-r--r--src/python/doc/witness_complex_user.rst4
-rw-r--r--src/python/gudhi/simplex_tree.pxd2
-rw-r--r--src/python/gudhi/simplex_tree.pyx51
-rw-r--r--src/python/gudhi/wasserstein/__init__.py1
-rw-r--r--src/python/gudhi/wasserstein/barycenter.py159
-rw-r--r--src/python/gudhi/wasserstein/wasserstein.py (renamed from src/python/gudhi/wasserstein.py)15
-rw-r--r--src/python/include/Simplex_tree_interface.h40
-rwxr-xr-xsrc/python/test/test_simplex_tree.py82
-rwxr-xr-xsrc/python/test/test_wasserstein_barycenter.py46
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py2
30 files changed, 711 insertions, 49 deletions
diff --git a/.github/next_release.md b/.github/next_release.md
index 3166b0a8..83b98a1c 100644
--- a/.github/next_release.md
+++ b/.github/next_release.md
@@ -9,6 +9,7 @@ Below is a list of changes made since GUDHI 3.1.1:
- [Wassertein distance](https://gudhi.inria.fr/python/latest/wasserstein_distance_user.html)
- An another implementation comes from Hera (BSD-3-Clause) which is based on [Geometry Helps to Compare Persistence Diagrams](http://doi.acm.org/10.1145/3064175) by Michael Kerber, Dmitriy Morozov, and Arnur Nigmetov.
+ - `gudhi.wasserstein.wasserstein_distance` has now an option to return the optimal matching that achieves the distance between the two diagrams.
- [Module](link)
- ...
diff --git a/.github/test-requirements.txt b/.github/test-requirements.txt
index 18882792..fb1df134 100644
--- a/.github/test-requirements.txt
+++ b/.github/test-requirements.txt
@@ -6,4 +6,8 @@ matplotlib
scipy
scikit-learn
POT
-tensorflow \ No newline at end of file
+tensorflow
+torch
+pykeops
+hnswlib
+eagerpy
diff --git a/Dockerfile_for_circleci_image b/Dockerfile_for_circleci_image
index 1eededb5..c2e8a8f5 100644
--- a/Dockerfile_for_circleci_image
+++ b/Dockerfile_for_circleci_image
@@ -43,6 +43,7 @@ RUN apt-get install -y make \
python3 \
python3-pip \
python3-tk \
+ python3-grpcio \
libfreetype6-dev \
pkg-config \
curl
@@ -51,7 +52,7 @@ ADD .github/build-requirements.txt /
ADD .github/test-requirements.txt /
RUN pip3 install -r build-requirements.txt
-RUN pip3 install -r test-requirements.txt
+RUN pip3 --no-cache-dir install -r test-requirements.txt
# apt clean up
RUN apt autoremove && rm -rf /var/lib/apt/lists/*
diff --git a/biblio/bibliography.bib b/biblio/bibliography.bib
index 056ccd72..2bfc28e3 100644
--- a/biblio/bibliography.bib
+++ b/biblio/bibliography.bib
@@ -7,11 +7,13 @@
}
@article{Carriere17c,
- author = {Carri\`ere, Mathieu and Michel, Bertrand and Oudot, Steve},
- title = {{Statistical Analysis and Parameter Selection for Mapper}},
- journal = {CoRR},
- volume = {abs/1706.00204},
- year = {2017}
+author = {Carri{\`{e}}re, Mathieu and Michel, Bertrand and Oudot, Steve},
+journal = {Journal of Machine Learning Research},
+pages = {1--39},
+publisher = {JMLR.org},
+title = {{Statistical analysis and parameter selection for Mapper}},
+volume = {19},
+year = {2018}
}
@inproceedings{Dey13,
@@ -23,11 +25,14 @@
}
@article{Carriere16,
- title={{Structure and Stability of the 1-Dimensional Mapper}},
- author={Carri\`ere, Mathieu and Oudot, Steve},
- journal={CoRR},
- volume= {abs/1511.05823},
- year={2015}
+author = {Carri{\`{e}}re, Mathieu and Oudot, Steve},
+journal = {Foundations of Computational Mathematics},
+number = {6},
+pages = {1333--1396},
+publisher = {Springer-Verlag},
+title = {{Structure and stability of the one-dimensional Mapper}},
+volume = {18},
+year = {2017}
}
@inproceedings{zigzag_reflection,
@@ -36,6 +41,17 @@
year = {2014 $\ \ \ \ \ \ \ \ \ \ \ $ \emph{In Preparation}},
}
+@article{Cohen-Steiner2009,
+author = {Cohen-Steiner, David and Edelsbrunner, Herbert and Harer, John},
+journal = {Foundations of Computational Mathematics},
+number = {1},
+pages = {79--103},
+publisher = {Springer-Verlag},
+title = {{Extending persistence using Poincar{\'{e}} and Lefschetz duality}},
+volume = {9},
+year = {2009}
+}
+
@misc{gudhi_stpcoh,
author = {Cl\'ement Maria},
title = "\textsc{Gudhi}, Simplex Tree and Persistent Cohomology Packages",
@@ -1219,3 +1235,13 @@ url = "https://doi.org/10.1214/11-EJS606",
volume = "5",
year = "2011"
}
+@article{turner2014frechet,
+ title={Fr{\'e}chet means for distributions of persistence diagrams},
+ author={Turner, Katharine and Mileyko, Yuriy and Mukherjee, Sayan and Harer, John},
+ journal={Discrete \& Computational Geometry},
+ volume={52},
+ number={1},
+ pages={44--70},
+ year={2014},
+ publisher={Springer}
+}
diff --git a/src/Simplex_tree/include/gudhi/Simplex_tree.h b/src/Simplex_tree/include/gudhi/Simplex_tree.h
index 430d1ac4..b455ae31 100644
--- a/src/Simplex_tree/include/gudhi/Simplex_tree.h
+++ b/src/Simplex_tree/include/gudhi/Simplex_tree.h
@@ -42,6 +42,20 @@
namespace Gudhi {
+/**
+ * \class Extended_simplex_type Simplex_tree.h gudhi/Simplex_tree.h
+ * \brief Extended simplex type data structure for representing the type of simplices in an extended filtration.
+ *
+ * \details The extended simplex type can be either UP (which means
+ * that the simplex was present originally, and is thus part of the ascending extended filtration), DOWN (which means
+ * that the simplex is the cone of an original simplex, and is thus part of the descending extended filtration) or
+ * EXTRA (which means the simplex is the cone point).
+ *
+ * Details may be found in \cite Cohen-Steiner2009 and section 2.2 in \cite Carriere16.
+ *
+ */
+enum class Extended_simplex_type {UP, DOWN, EXTRA};
+
struct Simplex_tree_options_full_featured;
/**
@@ -87,6 +101,8 @@ class Simplex_tree {
/* \brief Set of nodes sharing a same parent in the simplex tree. */
typedef Simplex_tree_siblings<Simplex_tree, Dictionary> Siblings;
+
+
struct Key_simplex_base_real {
Key_simplex_base_real() : key_(-1) {}
void assign_key(Simplex_key k) { key_ = k; }
@@ -100,6 +116,12 @@ class Simplex_tree {
void assign_key(Simplex_key);
Simplex_key key() const;
};
+ struct Extended_filtration_data {
+ Filtration_value minval;
+ Filtration_value maxval;
+ Extended_filtration_data(){}
+ Extended_filtration_data(Filtration_value vmin, Filtration_value vmax): minval(vmin), maxval(vmax) {}
+ };
typedef typename std::conditional<Options::store_key, Key_simplex_base_real, Key_simplex_base_dummy>::type
Key_simplex_base;
@@ -1471,6 +1493,108 @@ class Simplex_tree {
}
}
+ /** \brief Retrieve the original filtration value for a given simplex in the Simplex_tree. Since the
+ * computation of extended persistence requires modifying the filtration values, this function can be used
+ * to recover the original values. Moreover, computing extended persistence requires adding new simplices
+ * in the Simplex_tree. Hence, this function also outputs the type of each simplex. It can be either UP (which means
+ * that the simplex was present originally, and is thus part of the ascending extended filtration), DOWN (which means
+ * that the simplex is the cone of an original simplex, and is thus part of the descending extended filtration) or
+ * EXTRA (which means the simplex is the cone point). See the definition of Extended_simplex_type. Note that if the simplex type is DOWN, the original filtration value
+ * is set to be the original filtration value of the corresponding (not coned) original simplex.
+ * \pre This function should be called only if `extend_filtration()` has been called first!
+ * \post The output filtration value is supposed to be the same, but might be a little different, than the
+ * original filtration value, due to the internal transformation (scaling to [-2,-1]) that is
+ * performed on the original filtration values during the computation of extended persistence.
+ * @param[in] f Filtration value of the simplex in the extended (i.e., modified) filtration.
+ * @param[in] efd Structure containing the minimum and maximum values of the original filtration. This the output of `extend_filtration()`.
+ * @return A pair containing the original filtration value of the simplex as well as the simplex type.
+ */
+ std::pair<Filtration_value, Extended_simplex_type> decode_extended_filtration(Filtration_value f, const Extended_filtration_data& efd){
+ std::pair<Filtration_value, Extended_simplex_type> p;
+ Filtration_value minval = efd.minval;
+ Filtration_value maxval = efd.maxval;
+ if (f >= -2 && f <= -1){
+ p.first = minval + (maxval-minval)*(f + 2); p.second = Extended_simplex_type::UP;
+ }
+ else if (f >= 1 && f <= 2){
+ p.first = minval - (maxval-minval)*(f - 2); p.second = Extended_simplex_type::DOWN;
+ }
+ else{
+ p.first = std::numeric_limits<Filtration_value>::quiet_NaN(); p.second = Extended_simplex_type::EXTRA;
+ }
+ return p;
+ };
+
+ /** \brief Extend filtration for computing extended persistence.
+ * This function only uses the filtration values at the 0-dimensional simplices,
+ * and computes the extended persistence diagram induced by the lower-star filtration
+ * computed with these values.
+ * \post Note that after calling this function, the filtration
+ * values are actually modified. The function `decode_extended_filtration()`
+ * retrieves the original values and outputs the extended simplex type.
+ * \pre Note that this code creates an extra vertex internally, so you should make sure that
+ * the Simplex tree does not contain a vertex with the largest Vertex_handle.
+ * @return A data structure containing the maximum and minimum values of the original filtration.
+ * It is meant to be provided as input to `decode_extended_filtration()` in order to retrieve
+ * the original filtration values for each simplex.
+ */
+ Extended_filtration_data extend_filtration() {
+
+ // Compute maximum and minimum of filtration values
+ Vertex_handle maxvert = std::numeric_limits<Vertex_handle>::min();
+ Filtration_value minval = std::numeric_limits<Filtration_value>::infinity();
+ Filtration_value maxval = -std::numeric_limits<Filtration_value>::infinity();
+ for (auto sh = root_.members().begin(); sh != root_.members().end(); ++sh){
+ Filtration_value f = this->filtration(sh);
+ minval = std::min(minval, f);
+ maxval = std::max(maxval, f);
+ maxvert = std::max(sh->first, maxvert);
+ }
+
+ GUDHI_CHECK(maxvert < std::numeric_limits<Vertex_handle>::max(), std::invalid_argument("Simplex_tree contains a vertex with the largest Vertex_handle"));
+ maxvert += 1;
+
+ Simplex_tree st_copy = *this;
+
+ // Add point for coning the simplicial complex
+ this->insert_simplex({maxvert}, -3);
+
+ // For each simplex
+ std::vector<Vertex_handle> vr;
+ for (auto sh_copy : st_copy.complex_simplex_range()){
+
+ // Locate simplex
+ vr.clear();
+ for (auto vh : st_copy.simplex_vertex_range(sh_copy)){
+ vr.push_back(vh);
+ }
+ auto sh = this->find(vr);
+
+ // Create cone on simplex
+ vr.push_back(maxvert);
+ if (this->dimension(sh) == 0){
+ Filtration_value v = this->filtration(sh);
+ Filtration_value scaled_v = (v-minval)/(maxval-minval);
+ // Assign ascending value between -2 and -1 to vertex
+ this->assign_filtration(sh, -2 + scaled_v);
+ // Assign descending value between 1 and 2 to cone on vertex
+ this->insert_simplex(vr, 2 - scaled_v);
+ }
+ else{
+ // Assign value -3 to simplex and cone on simplex
+ this->assign_filtration(sh, -3);
+ this->insert_simplex(vr, -3);
+ }
+ }
+
+ // Automatically assign good values for simplices
+ this->make_filtration_non_decreasing();
+
+ // Return the filtration data
+ Extended_filtration_data efd(minval, maxval);
+ return efd;
+ }
+
/** \brief Returns a vertex of `sh` that has the same filtration value as `sh` if it exists, and `null_vertex()` otherwise.
*
* For a lower-star filtration built with `make_filtration_non_decreasing()`, this is a way to invert the process and find out which vertex had its filtration value propagated to `sh`.
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt
index 99e8b57c..10dcd161 100644
--- a/src/python/CMakeLists.txt
+++ b/src/python/CMakeLists.txt
@@ -229,7 +229,7 @@ if(PYTHONINTERP_FOUND)
# Other .py files
file(COPY "gudhi/persistence_graphical_tools.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
file(COPY "gudhi/representations" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi/")
- file(COPY "gudhi/wasserstein.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
+ file(COPY "gudhi/wasserstein" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
file(COPY "gudhi/point_cloud" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
add_custom_command(
@@ -402,6 +402,7 @@ if(PYTHONINTERP_FOUND)
# Wasserstein
if(OT_FOUND AND PYBIND11_FOUND)
add_gudhi_py_test(test_wasserstein_distance)
+ add_gudhi_py_test(test_wasserstein_barycenter)
endif()
# Representations
diff --git a/src/python/doc/alpha_complex_user.rst b/src/python/doc/alpha_complex_user.rst
index 60319e84..265a82d2 100644
--- a/src/python/doc/alpha_complex_user.rst
+++ b/src/python/doc/alpha_complex_user.rst
@@ -204,8 +204,8 @@ the program output is:
[3, 6] -> 30.25
CGAL citations
-==============
+--------------
.. bibliography:: ../../biblio/how_to_cite_cgal.bib
- :filter: docnames
+ :filter: docname in docnames
:style: unsrt
diff --git a/src/python/doc/bottleneck_distance_user.rst b/src/python/doc/bottleneck_distance_user.rst
index 9435c7f1..206fcb63 100644
--- a/src/python/doc/bottleneck_distance_user.rst
+++ b/src/python/doc/bottleneck_distance_user.rst
@@ -65,3 +65,10 @@ The output is:
Bottleneck distance approximation = 0.81
Bottleneck distance value = 0.75
+
+Bibliography
+------------
+
+.. bibliography:: ../../biblio/bibliography.bib
+ :filter: docname in docnames
+ :style: unsrt
diff --git a/src/python/doc/cubical_complex_user.rst b/src/python/doc/cubical_complex_user.rst
index 93ca6b24..e8c94bf6 100644
--- a/src/python/doc/cubical_complex_user.rst
+++ b/src/python/doc/cubical_complex_user.rst
@@ -160,8 +160,8 @@ Examples.
End user programs are available in python/example/ folder.
Bibliography
-============
+------------
.. bibliography:: ../../biblio/bibliography.bib
- :filter: docnames
+ :filter: docname in docnames
:style: unsrt
diff --git a/src/python/doc/img/barycenter.png b/src/python/doc/img/barycenter.png
new file mode 100644
index 00000000..cad6af70
--- /dev/null
+++ b/src/python/doc/img/barycenter.png
Binary files differ
diff --git a/src/python/doc/index.rst b/src/python/doc/index.rst
index 3387a64f..c153cdfc 100644
--- a/src/python/doc/index.rst
+++ b/src/python/doc/index.rst
@@ -71,6 +71,7 @@ Wasserstein distance
.. include:: wasserstein_distance_sum.inc
+
Persistence representations
===========================
@@ -90,5 +91,5 @@ Bibliography
************
.. bibliography:: ../../biblio/bibliography.bib
- :filter: docnames
+ :filter: docname in docnames
:style: unsrt
diff --git a/src/python/doc/installation.rst b/src/python/doc/installation.rst
index d459145b..48425d5e 100644
--- a/src/python/doc/installation.rst
+++ b/src/python/doc/installation.rst
@@ -175,8 +175,8 @@ Documentation
To build the documentation, `sphinx-doc <http://www.sphinx-doc.org>`_ and
`sphinxcontrib-bibtex <https://sphinxcontrib-bibtex.readthedocs.io>`_ are
required. As the documentation is auto-tested, `CGAL`_, `Eigen`_,
-`Matplotlib`_, `NumPy`_ and `SciPy`_ are also mandatory to build the
-documentation.
+`Matplotlib`_, `NumPy`_, `POT`_, `Scikit-learn`_ and `SciPy`_ are
+also mandatory to build the documentation.
Run the following commands in a terminal:
@@ -192,8 +192,8 @@ CGAL
====
Some GUDHI modules (cf. :doc:`modules list </index>`), and few examples
-require CGAL, a C++ library that provides easy access to efficient and
-reliable geometric algorithms.
+require `CGAL <https://www.cgal.org/>`_, a C++ library that provides easy
+access to efficient and reliable geometric algorithms.
The procedure to install this library
diff --git a/src/python/doc/nerve_gic_complex_ref.rst b/src/python/doc/nerve_gic_complex_ref.rst
index abde2e8c..6a81b7af 100644
--- a/src/python/doc/nerve_gic_complex_ref.rst
+++ b/src/python/doc/nerve_gic_complex_ref.rst
@@ -12,3 +12,10 @@ Cover complexes reference manual
:show-inheritance:
.. automethod:: gudhi.CoverComplex.__init__
+
+Bibliography
+------------
+
+.. bibliography:: ../../biblio/bibliography.bib
+ :filter: docname in docnames
+ :style: unsrt
diff --git a/src/python/doc/nerve_gic_complex_user.rst b/src/python/doc/nerve_gic_complex_user.rst
index 9101f45d..f709ce91 100644
--- a/src/python/doc/nerve_gic_complex_user.rst
+++ b/src/python/doc/nerve_gic_complex_user.rst
@@ -313,3 +313,10 @@ the program outputs again SC.dot which gives the following visualization after u
:alt: Visualization with neato
Visualization with neato
+
+Bibliography
+------------
+
+.. bibliography:: ../../biblio/bibliography.bib
+ :filter: docname in docnames
+ :style: unsrt
diff --git a/src/python/doc/persistent_cohomology_user.rst b/src/python/doc/persistent_cohomology_user.rst
index 5f931b3a..506fa3a7 100644
--- a/src/python/doc/persistent_cohomology_user.rst
+++ b/src/python/doc/persistent_cohomology_user.rst
@@ -113,8 +113,8 @@ We provide several example files: run these examples with -h for details on thei
* :download:`tangential_complex_plain_homology_from_off_file_example.py <../example/tangential_complex_plain_homology_from_off_file_example.py>`
Bibliography
-============
+------------
.. bibliography:: ../../biblio/bibliography.bib
- :filter: docnames
+ :filter: docname in docnames
:style: unsrt
diff --git a/src/python/doc/rips_complex_user.rst b/src/python/doc/rips_complex_user.rst
index 8efb12e6..c4bbcfb6 100644
--- a/src/python/doc/rips_complex_user.rst
+++ b/src/python/doc/rips_complex_user.rst
@@ -347,3 +347,10 @@ until dimension 1 - one skeleton graph in other words), the output is:
points in the persistence diagram will be under the diagonal, and
bottleneck distance and persistence graphical tool will not work properly,
this is a known issue.
+
+Bibliography
+------------
+
+.. bibliography:: ../../biblio/bibliography.bib
+ :filter: docname in docnames
+ :style: unsrt
diff --git a/src/python/doc/simplex_tree_user.rst b/src/python/doc/simplex_tree_user.rst
index 3df7617f..1b272c35 100644
--- a/src/python/doc/simplex_tree_user.rst
+++ b/src/python/doc/simplex_tree_user.rst
@@ -66,3 +66,10 @@ The output is:
([1, 2], 4.0)
([1], 0.0)
([2], 4.0)
+
+Bibliography
+------------
+
+.. bibliography:: ../../biblio/bibliography.bib
+ :filter: docname in docnames
+ :style: unsrt
diff --git a/src/python/doc/tangential_complex_user.rst b/src/python/doc/tangential_complex_user.rst
index 852cf5b6..cf8199cc 100644
--- a/src/python/doc/tangential_complex_user.rst
+++ b/src/python/doc/tangential_complex_user.rst
@@ -197,8 +197,8 @@ The output is:
Bibliography
-============
+------------
.. bibliography:: ../../biblio/bibliography.bib
- :filter: docnames
+ :filter: docname in docnames
:style: unsrt
diff --git a/src/python/doc/wasserstein_distance_sum.inc b/src/python/doc/wasserstein_distance_sum.inc
index 0ff22035..f9308e5e 100644
--- a/src/python/doc/wasserstein_distance_sum.inc
+++ b/src/python/doc/wasserstein_distance_sum.inc
@@ -3,11 +3,11 @@
+-----------------------------------------------------------------+----------------------------------------------------------------------+------------------------------------------------------------------+
| .. figure:: | The q-Wasserstein distance measures the similarity between two | :Author: Theo Lacombe |
- | ../../doc/Bottleneck_distance/perturb_pd.png | persistence diagrams. It's the minimum value c that can be achieved | |
- | :figclass: align-center | by a perfect matching between the points of the two diagrams (+ all | :Since: GUDHI 3.1.0 |
- | | diagonal points), where the value of a matching is defined as the | |
- | Wasserstein distance is the q-th root of the sum of the | q-th root of the sum of all edge lengths to the power q. Edge lengths| :License: MIT |
- | edge lengths to the power q. | are measured in norm p, for :math:`1 \leq p \leq \infty`. | |
+ | ../../doc/Bottleneck_distance/perturb_pd.png | persistence diagrams using the sum of all edges lengths (instead of | |
+ | :figclass: align-center | the maximum). It allows to define sophisticated objects such as | :Since: GUDHI 3.1.0 |
+ | | barycenters of a family of persistence diagrams. | |
+ | | | :License: MIT |
+ | | | |
| | | :Requires: Python Optimal Transport (POT) :math:`\geq` 0.5.1 |
+-----------------------------------------------------------------+----------------------------------------------------------------------+------------------------------------------------------------------+
| * :doc:`wasserstein_distance_user` | |
diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst
index a9b21fa5..c24da74d 100644
--- a/src/python/doc/wasserstein_distance_user.rst
+++ b/src/python/doc/wasserstein_distance_user.rst
@@ -9,10 +9,16 @@ Definition
.. include:: wasserstein_distance_sum.inc
-Functions
----------
-This implementation uses the Python Optimal Transport library and is based on
-ideas from "Large Scale Computation of Means and Cluster for Persistence
+The q-Wasserstein distance is defined as the minimal value achieved
+by a perfect matching between the points of the two diagrams (+ all
+diagonal points), where the value of a matching is defined as the
+q-th root of the sum of all edge lengths to the power q. Edge lengths
+are measured in norm p, for :math:`1 \leq p \leq \infty`.
+
+Distance Functions
+------------------
+This first implementation uses the Python Optimal Transport library and is based
+on ideas from "Large Scale Computation of Means and Cluster for Persistence
Diagrams via Optimal Transport" :cite:`10.5555/3327546.3327645`.
.. autofunction:: gudhi.wasserstein.wasserstein_distance
@@ -26,7 +32,7 @@ Morozov, and Arnur Nigmetov.
.. autofunction:: gudhi.hera.wasserstein_distance
Basic example
--------------
+*************
This example computes the 1-Wasserstein distance from 2 persistence diagrams with Euclidean ground metric.
Note that persistence diagrams must be submitted as (n x 2) numpy arrays and must not contain inf values.
@@ -48,9 +54,9 @@ The output is:
Wasserstein distance value = 1.45
-We can also have access to the optimal matching by letting `matching=True`.
+We can also have access to the optimal matching by letting `matching=True`.
It is encoded as a list of indices (i,j), meaning that the i-th point in X
-is mapped to the j-th point in Y.
+is mapped to the j-th point in Y.
An index of -1 represents the diagonal.
.. testcode::
@@ -78,9 +84,90 @@ An index of -1 represents the diagonal.
The output is:
.. testoutput::
-
+
Wasserstein distance value = 2.15
point 0 in dgm1 is matched to point 0 in dgm2
point 1 in dgm1 is matched to point 2 in dgm2
point 2 in dgm1 is matched to the diagonal
point 1 in dgm2 is matched to the diagonal
+
+Barycenters
+-----------
+
+A Frechet mean (or barycenter) is a generalization of the arithmetic
+mean in a non linear space such as the one of persistence diagrams.
+Given a set of persistence diagrams :math:`\mu_1 \dots \mu_n`, it is
+defined as a minimizer of the variance functional, that is of
+:math:`\mu \mapsto \sum_{i=1}^n d_2(\mu,\mu_i)^2`.
+where :math:`d_2` denotes the Wasserstein-2 distance between
+persistence diagrams.
+It is known to exist and is generically unique. However, an exact
+computation is in general untractable. Current implementation
+available is based on (Turner et al., 2014),
+:cite:`turner2014frechet`
+and uses an EM-scheme to
+provide a local minimum of the variance functional (somewhat similar
+to the Lloyd algorithm to estimate a solution to the k-means
+problem). The local minimum returned depends on the initialization of
+the barycenter.
+The combinatorial structure of the algorithm limits its
+performances on large scale problems (thousands of diagrams and of points
+per diagram).
+
+.. figure::
+ ./img/barycenter.png
+ :figclass: align-center
+
+ Illustration of Frechet mean between persistence
+ diagrams.
+
+
+.. autofunction:: gudhi.wasserstein.barycenter.lagrangian_barycenter
+
+Basic example
+*************
+
+This example estimates the Frechet mean (aka Wasserstein barycenter) between
+four persistence diagrams.
+It is initialized on the 4th diagram.
+As the algorithm is not convex, its output depends on the initialization and
+is only a local minimum of the objective function.
+Initialization can be either given as an integer (in which case the i-th
+diagram of the list is used as initial estimate) or as a diagram.
+If None, it will randomly select one of the diagrams of the list
+as initial estimate.
+Note that persistence diagrams must be submitted as
+(n x 2) numpy arrays and must not contain inf values.
+
+
+.. testcode::
+
+ from gudhi.wasserstein.barycenter import lagrangian_barycenter
+ import numpy as np
+
+ dg1 = np.array([[0.2, 0.5]])
+ dg2 = np.array([[0.2, 0.7]])
+ dg3 = np.array([[0.3, 0.6], [0.7, 0.8], [0.2, 0.3]])
+ dg4 = np.array([])
+ pdiagset = [dg1, dg2, dg3, dg4]
+ bary = lagrangian_barycenter(pdiagset=pdiagset,init=3)
+
+ message = "Wasserstein barycenter estimated:"
+ print(message)
+ print(bary)
+
+The output is:
+
+.. testoutput::
+
+ Wasserstein barycenter estimated:
+ [[0.27916667 0.55416667]
+ [0.7375 0.7625 ]
+ [0.2375 0.2625 ]]
+
+Bibliography
+------------
+
+.. bibliography:: ../../biblio/bibliography.bib
+ :filter: docname in docnames
+ :style: unsrt
diff --git a/src/python/doc/witness_complex_user.rst b/src/python/doc/witness_complex_user.rst
index 7087fa98..799f5444 100644
--- a/src/python/doc/witness_complex_user.rst
+++ b/src/python/doc/witness_complex_user.rst
@@ -128,8 +128,8 @@ Here is an example of constructing a strong witness complex filtration and compu
* :download:`euclidean_strong_witness_complex_diagram_persistence_from_off_file_example.py <../example/euclidean_strong_witness_complex_diagram_persistence_from_off_file_example.py>`
Bibliography
-============
+------------
.. bibliography:: ../../biblio/bibliography.bib
- :filter: docnames
+ :filter: docname in docnames
:style: unsrt
diff --git a/src/python/gudhi/simplex_tree.pxd b/src/python/gudhi/simplex_tree.pxd
index 82f155de..595f22bb 100644
--- a/src/python/gudhi/simplex_tree.pxd
+++ b/src/python/gudhi/simplex_tree.pxd
@@ -57,6 +57,8 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi":
void remove_maximal_simplex(vector[int] simplex)
bool prune_above_filtration(double filtration)
bool make_filtration_non_decreasing()
+ void compute_extended_filtration()
+ vector[vector[pair[int, pair[double, double]]]] compute_extended_persistence_subdiagrams(vector[pair[int, pair[double, double]]] dgm, double min_persistence)
# Iterators over Simplex tree
pair[vector[int], double] get_simplex_and_filtration(Simplex_tree_simplex_handle f_simplex)
Simplex_tree_simplices_iterator get_simplices_iterator_begin()
diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx
index c01cc905..cc3753e1 100644
--- a/src/python/gudhi/simplex_tree.pyx
+++ b/src/python/gudhi/simplex_tree.pyx
@@ -396,6 +396,57 @@ cdef class SimplexTree:
"""
return self.get_ptr().make_filtration_non_decreasing()
+ def extend_filtration(self):
+ """ Extend filtration for computing extended persistence. This function only uses the
+ filtration values at the 0-dimensional simplices, and computes the extended persistence
+ diagram induced by the lower-star filtration computed with these values.
+
+ .. note::
+
+ Note that after calling this function, the filtration
+ values are actually modified within the Simplex_tree.
+ The function :func:`extended_persistence()<gudhi.SimplexTree.extended_persistence>`
+ retrieves the original values.
+
+ .. note::
+
+ Note that this code creates an extra vertex internally, so you should make sure that
+ the Simplex_tree does not contain a vertex with the largest possible value (i.e., 4294967295).
+ """
+ return self.get_ptr().compute_extended_filtration()
+
+ def extended_persistence(self, homology_coeff_field=11, min_persistence=0):
+ """This function retrieves good values for extended persistence, and separate the diagrams
+ into the Ordinary, Relative, Extended+ and Extended- subdiagrams.
+
+ :param homology_coeff_field: The homology coefficient field. Must be a
+ prime number. Default value is 11.
+ :type homology_coeff_field: int.
+ :param min_persistence: The minimum persistence value (i.e., the absolute value of the difference between the persistence diagram point coordinates) to take into
+ account (strictly greater than min_persistence). Default value is
+ 0.0.
+ Sets min_persistence to -1.0 to see all values.
+ :type min_persistence: float.
+ :returns: A list of four persistence diagrams in the format described in :func:`persistence()<gudhi.SimplexTree.persistence>`. The first one is Ordinary, the second one is Relative, the third one is Extended+ and the fourth one is Extended-. See https://link.springer.com/article/10.1007/s10208-008-9027-z and/or section 2.2 in https://link.springer.com/article/10.1007/s10208-017-9370-z for a description of these subtypes.
+
+ .. note::
+
+ This function should be called only if :func:`extend_filtration()<gudhi.SimplexTree.extend_filtration>` has been called first!
+
+ .. note::
+
+ The coordinates of the persistence diagram points might be a little different than the
+ original filtration values due to the internal transformation (scaling to [-2,-1]) that is
+ performed on these values during the computation of extended persistence.
+ """
+ cdef vector[pair[int, pair[double, double]]] persistence_result
+ if self.pcohptr != NULL:
+ del self.pcohptr
+ self.pcohptr = new Simplex_tree_persistence_interface(self.get_ptr(), False)
+ persistence_result = self.pcohptr.get_persistence(homology_coeff_field, -1.)
+ return self.get_ptr().compute_extended_persistence_subdiagrams(persistence_result, min_persistence)
+
+
def persistence(self, homology_coeff_field=11, min_persistence=0, persistence_dim_max = False):
"""This function returns the persistence of the simplicial complex.
diff --git a/src/python/gudhi/wasserstein/__init__.py b/src/python/gudhi/wasserstein/__init__.py
new file mode 100644
index 00000000..ed225ba4
--- /dev/null
+++ b/src/python/gudhi/wasserstein/__init__.py
@@ -0,0 +1 @@
+from .wasserstein import wasserstein_distance
diff --git a/src/python/gudhi/wasserstein/barycenter.py b/src/python/gudhi/wasserstein/barycenter.py
new file mode 100644
index 00000000..de7aea81
--- /dev/null
+++ b/src/python/gudhi/wasserstein/barycenter.py
@@ -0,0 +1,159 @@
+# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+# Author(s): Theo Lacombe
+#
+# Copyright (C) 2019 Inria
+#
+# Modification(s):
+# - YYYY/MM Author: Description of the modification
+
+
+import ot
+import numpy as np
+import scipy.spatial.distance as sc
+
+from gudhi.wasserstein import wasserstein_distance
+
+
+def _mean(x, m):
+ '''
+ :param x: a list of 2D-points, off diagonal, x_0... x_{k-1}
+ :param m: total amount of points taken into account,
+ that is we have (m-k) copies of diagonal
+ :returns: the weighted mean of x with (m-k) copies of the diagonal
+ '''
+ k = len(x)
+ if k > 0:
+ w = np.mean(x, axis=0)
+ w_delta = (w[0] + w[1]) / 2 * np.ones(2)
+ return (k * w + (m-k) * w_delta) / m
+ else:
+ return np.array([0, 0])
+
+
+def lagrangian_barycenter(pdiagset, init=None, verbose=False):
+ '''
+ :param pdiagset: a list of ``numpy.array`` of shape `(n x 2)`
+ (`n` can variate), encoding a set of
+ persistence diagrams with only finite coordinates.
+ :param init: The initial value for barycenter estimate.
+ If ``None``, init is made on a random diagram from the dataset.
+ Otherwise, it can be an ``int``
+ (then initialization is made on ``pdiagset[init]``)
+ or a `(n x 2)` ``numpy.array`` enconding
+ a persistence diagram with `n` points.
+ :type init: ``int``, or (n x 2) ``np.array``
+ :param verbose: if ``True``, returns additional information about the
+ barycenter.
+ :type verbose: boolean
+ :returns: If not verbose (default), a ``numpy.array`` encoding
+ the barycenter estimate of pdiagset
+ (local minimum of the energy function).
+ If ``pdiagset`` is empty, returns ``None``.
+ If verbose, returns a couple ``(Y, log)``
+ where ``Y`` is the barycenter estimate,
+ and ``log`` is a ``dict`` that contains additional informations:
+
+ - `"groupings"`, a list of list of pairs ``(i,j)``.
+ Namely, ``G[k] = [...(i, j)...]``, where ``(i,j)`` indicates
+ that ``pdiagset[k][i]`` is matched to ``Y[j]``
+ if ``i = -1`` or ``j = -1``, it means they
+ represent the diagonal.
+
+ - `"energy"`, ``float`` representing the Frechet energy value obtained.
+ It is the mean of squared distances of observations to the output.
+
+ - `"nb_iter"`, ``int`` number of iterations performed before convergence of the algorithm.
+ '''
+ X = pdiagset # to shorten notations, not a copy
+ m = len(X) # number of diagrams we are averaging
+ if m == 0:
+ print("Warning: computing barycenter of empty diag set. Returns None")
+ return None
+
+ # store the number of off-diagonal point for each of the X_i
+ nb_off_diag = np.array([len(X_i) for X_i in X])
+ # Initialisation of barycenter
+ if init is None:
+ i0 = np.random.randint(m) # Index of first state for the barycenter
+ Y = X[i0].copy()
+ else:
+ if type(init)==int:
+ Y = X[init].copy()
+ else:
+ Y = init.copy()
+
+ nb_iter = 0
+
+ converged = False # stoping criterion
+ while not converged:
+ nb_iter += 1
+ K = len(Y) # current nb of points in Y (some might be on diagonal)
+ G = np.full((K, m), -1, dtype=int) # will store for each j, the (index)
+ # point matched in each other diagram
+ #(might be the diagonal).
+ # that is G[j, i] = k <=> y_j is matched to
+ # x_k in the diagram i-th diagram X[i]
+ updated_points = np.zeros((K, 2)) # will store the new positions of
+ # the points of Y.
+ # If points disappear, there thrown
+ # on [0,0] by default.
+ new_created_points = [] # will store potential new points.
+
+ # Step 1 : compute optimal matching (Y, X_i) for each X_i
+ # and create new points in Y if needed
+ for i in range(m):
+ _, indices = wasserstein_distance(Y, X[i], matching=True, order=2., internal_p=2.)
+ for y_j, x_i_j in indices:
+ if y_j >= 0: # we matched an off diagonal point to x_i_j...
+ if x_i_j >= 0: # ...which is also an off-diagonal point.
+ G[y_j, i] = x_i_j
+ else: # ...which is a diagonal point
+ G[y_j, i] = -1 # -1 stands for the diagonal (mask)
+ else: # We matched a diagonal point to x_i_j...
+ if x_i_j >= 0: # which is a off-diag point !
+ # need to create new point in Y
+ new_y = _mean(np.array([X[i][x_i_j]]), m)
+ # Average this point with (m-1) copies of Delta
+ new_created_points.append(new_y)
+
+ # Step 2 : Update current point position thanks to groupings computed
+ to_delete = []
+ for j in range(K):
+ matched_points = [X[i][G[j, i]] for i in range(m) if G[j, i] > -1]
+ new_y_j = _mean(matched_points, m)
+ if not np.array_equal(new_y_j, np.array([0,0])):
+ updated_points[j] = new_y_j
+ else: # this points is no longer of any use.
+ to_delete.append(j)
+ # we remove the point to be deleted now.
+ updated_points = np.delete(updated_points, to_delete, axis=0)
+
+ # we cannot converge if there have been new created points.
+ if new_created_points:
+ Y = np.concatenate((updated_points, new_created_points))
+ else:
+ # Step 3 : we check convergence
+ if np.array_equal(updated_points, Y):
+ converged = True
+ Y = updated_points
+
+
+ if verbose:
+ groupings = []
+ energy = 0
+ log = {}
+ n_y = len(Y)
+ for i in range(m):
+ cost, edges = wasserstein_distance(Y, X[i], matching=True, order=2., internal_p=2.)
+ groupings.append(edges)
+ energy += cost
+ log["groupings"] = groupings
+ energy = energy/m
+ log["energy"] = energy
+ log["nb_iter"] = nb_iter
+
+ return Y, log
+ else:
+ return Y
+
diff --git a/src/python/gudhi/wasserstein.py b/src/python/gudhi/wasserstein/wasserstein.py
index 3dd993f9..35315939 100644
--- a/src/python/gudhi/wasserstein.py
+++ b/src/python/gudhi/wasserstein/wasserstein.py
@@ -9,6 +9,7 @@
import numpy as np
import scipy.spatial.distance as sc
+
try:
import ot
except ImportError:
@@ -29,9 +30,9 @@ def _build_dist_matrix(X, Y, order=2., internal_p=2.):
:param Y: (m x 2) numpy.array encoding the second diagram.
:param order: exponent for the Wasserstein metric.
:param internal_p: Ground metric (i.e. norm L^p).
- :returns: (n+1) x (m+1) np.array encoding the cost matrix C.
- For 0 <= i < n, 0 <= j < m, C[i,j] encodes the distance between X[i] and Y[j],
- while C[i, m] (resp. C[n, j]) encodes the distance (to the p) between X[i] (resp Y[j])
+ :returns: (n+1) x (m+1) np.array encoding the cost matrix C.
+ For 0 <= i < n, 0 <= j < m, C[i,j] encodes the distance between X[i] and Y[j],
+ while C[i, m] (resp. C[n, j]) encodes the distance (to the p) between X[i] (resp Y[j])
and its orthogonal projection onto the diagonal.
note also that C[n, m] = 0 (it costs nothing to move from the diagonal to the diagonal).
'''
@@ -58,7 +59,7 @@ def _perstot(X, order, internal_p):
:param X: (n x 2) numpy.array (points of a given diagram).
:param order: exponent for Wasserstein. Default value is 2.
:param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm).
- :returns: float, the total persistence of the diagram (that is, its distance to the empty diagram).
+ :returns: float, the total persistence of the diagram (that is, its distance to the empty diagram).
'''
Xdiag = _proj_on_diag(X)
return (np.sum(np.linalg.norm(X - Xdiag, ord=internal_p, axis=1)**order))**(1./order)
@@ -66,16 +67,16 @@ def _perstot(X, order, internal_p):
def wasserstein_distance(X, Y, matching=False, order=2., internal_p=2.):
'''
- :param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points
+ :param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points
(i.e. with infinite coordinate).
:param Y: (m x 2) numpy.array encoding the second diagram.
:param matching: if True, computes and returns the optimal matching between X and Y, encoded as
a (n x 2) np.array [...[i,j]...], meaning the i-th point in X is matched to
the j-th point in Y, with the convention (-1) represents the diagonal.
:param order: exponent for Wasserstein; Default value is 2.
- :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2);
+ :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2);
Default value is 2 (Euclidean norm).
- :returns: the Wasserstein distance of order q (1 <= q < infinity) between persistence diagrams with
+ :returns: the Wasserstein distance of order q (1 <= q < infinity) between persistence diagrams with
respect to the internal_p-norm as ground metric.
If matching is set to True, also returns the optimal matching between X and Y.
'''
diff --git a/src/python/include/Simplex_tree_interface.h b/src/python/include/Simplex_tree_interface.h
index 4a7062d6..1a18aed6 100644
--- a/src/python/include/Simplex_tree_interface.h
+++ b/src/python/include/Simplex_tree_interface.h
@@ -37,8 +37,12 @@ class Simplex_tree_interface : public Simplex_tree<SimplexTreeOptions> {
using Filtered_simplices = std::vector<Simplex_and_filtration>;
using Skeleton_simplex_iterator = typename Base::Skeleton_simplex_iterator;
using Complex_simplex_iterator = typename Base::Complex_simplex_iterator;
+ using Extended_filtration_data = typename Base::Extended_filtration_data;
public:
+
+ Extended_filtration_data efd;
+
bool find_simplex(const Simplex& vh) {
return (Base::find(vh) != Base::null_simplex());
}
@@ -117,6 +121,42 @@ class Simplex_tree_interface : public Simplex_tree<SimplexTreeOptions> {
return cofaces;
}
+ void compute_extended_filtration() {
+ this->efd = this->extend_filtration();
+ this->initialize_filtration();
+ return;
+ }
+
+ std::vector<std::vector<std::pair<int, std::pair<Filtration_value, Filtration_value>>>> compute_extended_persistence_subdiagrams(const std::vector<std::pair<int, std::pair<Filtration_value, Filtration_value>>>& dgm, Filtration_value min_persistence){
+ std::vector<std::vector<std::pair<int, std::pair<Filtration_value, Filtration_value>>>> new_dgm(4);
+ for (unsigned int i = 0; i < dgm.size(); i++){
+ std::pair<Filtration_value, Extended_simplex_type> px = this->decode_extended_filtration(dgm[i].second.first, this->efd);
+ std::pair<Filtration_value, Extended_simplex_type> py = this->decode_extended_filtration(dgm[i].second.second, this->efd);
+ std::pair<int, std::pair<Filtration_value, Filtration_value>> pd_point = std::make_pair(dgm[i].first, std::make_pair(px.first, py.first));
+ if(std::abs(px.first - py.first) > min_persistence){
+ //Ordinary
+ if (px.second == Extended_simplex_type::UP && py.second == Extended_simplex_type::UP){
+ new_dgm[0].push_back(pd_point);
+ }
+ // Relative
+ else if (px.second == Extended_simplex_type::DOWN && py.second == Extended_simplex_type::DOWN){
+ new_dgm[1].push_back(pd_point);
+ }
+ else{
+ // Extended+
+ if (px.first < py.first){
+ new_dgm[2].push_back(pd_point);
+ }
+ //Extended-
+ else{
+ new_dgm[3].push_back(pd_point);
+ }
+ }
+ }
+ }
+ return new_dgm;
+ }
+
void create_persistence(Gudhi::Persistent_cohomology_interface<Base>* pcoh) {
Base::initialize_filtration();
pcoh = new Gudhi::Persistent_cohomology_interface<Base>(*this);
diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py
index f7848379..70b26e97 100755
--- a/src/python/test/test_simplex_tree.py
+++ b/src/python/test/test_simplex_tree.py
@@ -9,6 +9,7 @@
"""
from gudhi import SimplexTree
+import pytest
__author__ = "Vincent Rouvreau"
__copyright__ = "Copyright (C) 2016 Inria"
@@ -250,6 +251,87 @@ def test_make_filtration_non_decreasing():
assert st.filtration([3, 4]) == 2.0
assert st.filtration([4, 5]) == 2.0
+def test_extend_filtration():
+
+ # Inserted simplex:
+ # 5 4
+ # o o
+ # / \ /
+ # o o
+ # /2\ /3
+ # o o
+ # 1 0
+
+ st = SimplexTree()
+ st.insert([0,2])
+ st.insert([1,2])
+ st.insert([0,3])
+ st.insert([2,5])
+ st.insert([3,4])
+ st.insert([3,5])
+ st.assign_filtration([0], 1.)
+ st.assign_filtration([1], 2.)
+ st.assign_filtration([2], 3.)
+ st.assign_filtration([3], 4.)
+ st.assign_filtration([4], 5.)
+ st.assign_filtration([5], 6.)
+
+ assert list(st.get_filtration()) == [
+ ([0, 2], 0.0),
+ ([1, 2], 0.0),
+ ([0, 3], 0.0),
+ ([3, 4], 0.0),
+ ([2, 5], 0.0),
+ ([3, 5], 0.0),
+ ([0], 1.0),
+ ([1], 2.0),
+ ([2], 3.0),
+ ([3], 4.0),
+ ([4], 5.0),
+ ([5], 6.0)
+ ]
+
+ st.extend_filtration()
+
+ assert list(st.get_filtration()) == [
+ ([6], -3.0),
+ ([0], -2.0),
+ ([1], -1.8),
+ ([2], -1.6),
+ ([0, 2], -1.6),
+ ([1, 2], -1.6),
+ ([3], -1.4),
+ ([0, 3], -1.4),
+ ([4], -1.2),
+ ([3, 4], -1.2),
+ ([5], -1.0),
+ ([2, 5], -1.0),
+ ([3, 5], -1.0),
+ ([5, 6], 1.0),
+ ([4, 6], 1.2),
+ ([3, 6], 1.4),
+ ([3, 4, 6], 1.4),
+ ([3, 5, 6], 1.4),
+ ([2, 6], 1.6),
+ ([2, 5, 6], 1.6),
+ ([1, 6], 1.8),
+ ([1, 2, 6], 1.8),
+ ([0, 6], 2.0),
+ ([0, 2, 6], 2.0),
+ ([0, 3, 6], 2.0)
+ ]
+
+ dgms = st.extended_persistence(min_persistence=-1.)
+
+ assert dgms[0][0][1][0] == pytest.approx(2.)
+ assert dgms[0][0][1][1] == pytest.approx(3.)
+ assert dgms[1][0][1][0] == pytest.approx(5.)
+ assert dgms[1][0][1][1] == pytest.approx(4.)
+ assert dgms[2][0][1][0] == pytest.approx(1.)
+ assert dgms[2][0][1][1] == pytest.approx(6.)
+ assert dgms[3][0][1][0] == pytest.approx(6.)
+ assert dgms[3][0][1][1] == pytest.approx(1.)
+
def test_simplices_iterator():
st = SimplexTree()
diff --git a/src/python/test/test_wasserstein_barycenter.py b/src/python/test/test_wasserstein_barycenter.py
new file mode 100755
index 00000000..f68c748e
--- /dev/null
+++ b/src/python/test/test_wasserstein_barycenter.py
@@ -0,0 +1,46 @@
+from gudhi.wasserstein.barycenter import lagrangian_barycenter
+import numpy as np
+
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Theo Lacombe
+
+ Copyright (C) 2019 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+__author__ = "Theo Lacombe"
+__copyright__ = "Copyright (C) 2019 Inria"
+__license__ = "MIT"
+
+
+def test_lagrangian_barycenter():
+
+ dg1 = np.array([[0.2, 0.5]])
+ dg2 = np.array([[0.2, 0.7]])
+ dg3 = np.array([[0.3, 0.6], [0.7, 0.8], [0.2, 0.3]])
+ dg4 = np.array([])
+ dg5 = np.array([])
+ dg6 = np.array([])
+ res = np.array([[0.27916667, 0.55416667], [0.7375, 0.7625], [0.2375, 0.2625]])
+
+ dg7 = np.array([[0.1, 0.15], [0.1, 0.7], [0.2, 0.22], [0.55, 0.84], [0.11, 0.91], [0.61, 0.75], [0.33, 0.46], [0.12, 0.41], [0.32, 0.48]])
+ dg8 = np.array([[0., 4.], [4, 8]])
+
+ # error crit.
+ eps = 1e-7
+
+
+ assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg1, dg2, dg3, dg4],init=3, verbose=False) - res) < eps
+ assert np.array_equal(lagrangian_barycenter(pdiagset=[dg4, dg5, dg6], verbose=False), np.empty(shape=(0,2)))
+ assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg7], verbose=False) - dg7) < eps
+ Y, log = lagrangian_barycenter(pdiagset=[dg4, dg8], verbose=True)
+ assert np.linalg.norm(Y - np.array([[1,3], [5, 7]])) < eps
+ assert np.abs(log["energy"] - 2) < eps
+ assert np.array_equal(log["groupings"][0] , np.array([[0, -1], [1, -1]]))
+ assert np.array_equal(log["groupings"][1] , np.array([[0, 0], [1, 1]]))
+ assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg8, dg4], init=np.array([[0.2, 0.6], [0.5, 0.7]]), verbose=False) - np.array([[1, 3], [5, 7]])) < eps
+ assert lagrangian_barycenter(pdiagset = []) is None
+
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 0d70e11a..7e0d0f5f 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -70,7 +70,7 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat
assert np.array_equal(match , [[0, -1], [1, -1]])
match = wasserstein_distance(diag1, diag2, matching=True, internal_p=2., order=2.)[1]
assert np.array_equal(match, [[0, 0], [1, 1], [2, -1]])
-
+
def hera_wrap(delta):