diff options
-rw-r--r-- | biblio/bibliography.bib | 11 | ||||
-rw-r--r-- | src/python/CMakeLists.txt | 3 | ||||
-rw-r--r-- | src/python/doc/img/barycenter.png | bin | 0 -> 12433 bytes | |||
-rw-r--r-- | src/python/doc/index.rst | 1 | ||||
-rw-r--r-- | src/python/doc/wasserstein_distance_sum.inc | 10 | ||||
-rw-r--r-- | src/python/doc/wasserstein_distance_user.rst | 91 | ||||
-rw-r--r-- | src/python/gudhi/barycenter.py | 158 | ||||
-rw-r--r-- | src/python/gudhi/wasserstein.py | 1 | ||||
-rwxr-xr-x | src/python/test/test_wasserstein_barycenter.py | 46 |
9 files changed, 312 insertions, 9 deletions
diff --git a/biblio/bibliography.bib b/biblio/bibliography.bib index b017a07e..a57be224 100644 --- a/biblio/bibliography.bib +++ b/biblio/bibliography.bib @@ -1208,3 +1208,14 @@ numpages = {11}, location = {Montr\'{e}al, Canada}, series = {NIPS’18} } + +@article{turner2014frechet, + title={Fr{\'e}chet means for distributions of persistence diagrams}, + author={Turner, Katharine and Mileyko, Yuriy and Mukherjee, Sayan and Harer, John}, + journal={Discrete \& Computational Geometry}, + volume={52}, + number={1}, + pages={44--70}, + year={2014}, + publisher={Springer} +} diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index f00966a5..b7d43bea 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -56,6 +56,7 @@ if(PYTHONINTERP_FOUND) # Modules that should not be auto-imported in __init__.py set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'representations', ") set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'wasserstein', ") + set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'barycenter', ") set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'point_cloud', ") add_gudhi_debug_info("Python version ${PYTHON_VERSION_STRING}") @@ -217,6 +218,7 @@ if(PYTHONINTERP_FOUND) file(COPY "gudhi/persistence_graphical_tools.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi") file(COPY "gudhi/representations" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi/") file(COPY "gudhi/wasserstein.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi") + file(COPY "gudhi/barycenter.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi") file(COPY "gudhi/point_cloud" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi") add_custom_command( @@ -389,6 +391,7 @@ if(PYTHONINTERP_FOUND) # Wasserstein if(OT_FOUND AND PYBIND11_FOUND) add_gudhi_py_test(test_wasserstein_distance) + add_gudhi_py_test(test_wasserstein_barycenter) endif() # Representations diff --git a/src/python/doc/img/barycenter.png b/src/python/doc/img/barycenter.png Binary files differnew file mode 100644 index 00000000..cad6af70 --- /dev/null +++ b/src/python/doc/img/barycenter.png diff --git a/src/python/doc/index.rst b/src/python/doc/index.rst index 3387a64f..0e484483 100644 --- a/src/python/doc/index.rst +++ b/src/python/doc/index.rst @@ -71,6 +71,7 @@ Wasserstein distance .. include:: wasserstein_distance_sum.inc + Persistence representations =========================== diff --git a/src/python/doc/wasserstein_distance_sum.inc b/src/python/doc/wasserstein_distance_sum.inc index 0ff22035..f10472bc 100644 --- a/src/python/doc/wasserstein_distance_sum.inc +++ b/src/python/doc/wasserstein_distance_sum.inc @@ -3,11 +3,11 @@ +-----------------------------------------------------------------+----------------------------------------------------------------------+------------------------------------------------------------------+ | .. figure:: | The q-Wasserstein distance measures the similarity between two | :Author: Theo Lacombe | - | ../../doc/Bottleneck_distance/perturb_pd.png | persistence diagrams. It's the minimum value c that can be achieved | | - | :figclass: align-center | by a perfect matching between the points of the two diagrams (+ all | :Since: GUDHI 3.1.0 | - | | diagonal points), where the value of a matching is defined as the | | - | Wasserstein distance is the q-th root of the sum of the | q-th root of the sum of all edge lengths to the power q. Edge lengths| :License: MIT | - | edge lengths to the power q. | are measured in norm p, for :math:`1 \leq p \leq \infty`. | | + | ../../doc/Bottleneck_distance/perturb_pd.png | persistence diagrams using the sum of all edges lengths (instead of | | + | :figclass: align-center | the maximum). It allows to define sophisticated objects such as | :Introduced in: GUDHI 3.1.0 | + | | barycenters of a family of persistence diagrams. | | + | Wasserstein distance is the q-th root of the sum of the | | :Copyright: MIT | + | edge lengths to the power q. | | | | | | :Requires: Python Optimal Transport (POT) :math:`\geq` 0.5.1 | +-----------------------------------------------------------------+----------------------------------------------------------------------+------------------------------------------------------------------+ | * :doc:`wasserstein_distance_user` | | diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst index a9b21fa5..6de05afc 100644 --- a/src/python/doc/wasserstein_distance_user.rst +++ b/src/python/doc/wasserstein_distance_user.rst @@ -9,10 +9,16 @@ Definition .. include:: wasserstein_distance_sum.inc -Functions ---------- -This implementation uses the Python Optimal Transport library and is based on -ideas from "Large Scale Computation of Means and Cluster for Persistence +The q-Wasserstein distance is defined as the minimal value +by a perfect matching between the points of the two diagrams (+ all +diagonal points), where the value of a matching is defined as the +q-th root of the sum of all edge lengths to the power q. Edge lengths +are measured in norm p, for :math:`1 \leq p \leq \infty`. + +Distance Functions +------------------ +This first implementation uses the Python Optimal Transport library and is based +on ideas from "Large Scale Computation of Means and Cluster for Persistence Diagrams via Optimal Transport" :cite:`10.5555/3327546.3327645`. .. autofunction:: gudhi.wasserstein.wasserstein_distance @@ -84,3 +90,80 @@ The output is: point 1 in dgm1 is matched to point 2 in dgm2 point 2 in dgm1 is matched to the diagonal point 1 in dgm2 is matched to the diagonal + + +Barycenters +----------- + +A Frechet mean (or barycenter) is a generalization of the arithmetic +mean in a non linear space such as the one of persistence diagrams. +Given a set of persistence diagrams :math:`\mu_1 \dots \mu_n`, it is +defined as a minimizer of the variance functional, that is of +:math:`\mu \mapsto \sum_{i=1}^n d_2(\mu,\mu_i)^2`. +where :math:`d_2` denotes the Wasserstein-2 distance between +persistence diagrams. +It is known to exist and is generically unique. However, an exact +computation is in general untractable. Current implementation +available is based on (Turner et al., 2014), +:cite:`turner2014frechet` +and uses an EM-scheme to +provide a local minimum of the variance functional (somewhat similar +to the Lloyd algorithm to estimate a solution to the k-means +problem). The local minimum returned depends on the initialization of +the barycenter. +The combinatorial structure of the algorithm limits its +scaling on large scale problems (thousands of diagrams and of points +per diagram). + +.. figure:: + ./img/barycenter.png + :figclass: align-center + + Illustration of Frechet mean between persistence + diagrams. + + +.. autofunction:: gudhi.barycenter.lagrangian_barycenter + +Basic example +------------- + +This example computes the Frechet mean (aka Wasserstein barycenter) between +four persistence diagrams. +It is initialized on the 4th diagram. +As the algorithm is not convex, its output depends on the initialization and +is only a local minimum of the objective function. +Initialization can be either given as an integer (in which case the i-th +diagram of the list is used as initial estimate) or as a diagram. +If None, it will randomly select one of the diagram of the list +as initial estimate. +Note that persistence diagrams must be submitted as +(n x 2) numpy arrays and must not contain inf values. + + +.. testcode:: + + import gudhi.barycenter + import numpy as np + + dg1 = np.array([[0.2, 0.5]]) + dg2 = np.array([[0.2, 0.7]]) + dg3 = np.array([[0.3, 0.6], [0.7, 0.8], [0.2, 0.3]]) + dg4 = np.array([]) + pdiagset = [dg1, dg2, dg3, dg4] + bary = gudhi.wasserstein.barycenter.lagrangian_barycenter(pdiagset=pdiagset,init=3) + + message = "Wasserstein barycenter estimated:" + print(message) + print(bary) + +The output is: + +.. testoutput:: + + Wasserstein barycenter estimated: + [[0.27916667 0.55416667] + [0.7375 0.7625 ] + [0.2375 0.2625 ]] + + diff --git a/src/python/gudhi/barycenter.py b/src/python/gudhi/barycenter.py new file mode 100644 index 00000000..079bcc57 --- /dev/null +++ b/src/python/gudhi/barycenter.py @@ -0,0 +1,158 @@ +# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. +# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. +# Author(s): Theo Lacombe +# +# Copyright (C) 2019 Inria +# +# Modification(s): +# - YYYY/MM Author: Description of the modification + + +import ot +import numpy as np +import scipy.spatial.distance as sc + +from gudhi.wasserstein import wasserstein_distance + + +def _mean(x, m): + ''' + :param x: a list of 2D-points, off diagonal, x_0... x_{k-1} + :param m: total amount of points taken into account, + that is we have (m-k) copies of diagonal + :returns: the weighted mean of x with (m-k) copies of the diagonal + ''' + k = len(x) + if k > 0: + w = np.mean(x, axis=0) + w_delta = (w[0] + w[1]) / 2 * np.ones(2) + return (k * w + (m-k) * w_delta) / m + else: + return np.array([0, 0]) + + +def lagrangian_barycenter(pdiagset, init=None, verbose=False): + ''' + :param pdiagset: a list of size m containing numpy.array of shape (n x 2) + (n can variate), encoding a set of + persistence diagrams with only finite coordinates. + :param init: The initial value for barycenter estimate. + If None, init is made on a random diagram from the dataset. + Otherwise, it must be an int + (then we init with diagset[init]) + or a (n x 2) numpy.array enconding + a persistence diagram with n points. + :param verbose: if True, returns additional information about the + barycenter. + :returns: If not verbose (default), a numpy.array encoding + the barycenter estimate of pdiagset + (local minima of the energy function). + If pdiagset is empty, returns None. + If verbose, returns a couple (Y, log) + where Y is the barycenter estimate, + and log is a dict that contains additional informations: + - groupings, a list of list of pairs (i,j), + That is, G[k] = [(i, j) ...], where (i,j) indicates + that X[i] is matched to Y[j] + if i = -1 or j = -1, it means they + represent the diagonal. + - energy, a float representing the Frechet + energy value obtained, + that is the mean of squared distances + of observations to the output. + - nb_iter, integer representing the number of iterations + performed before convergence of the algorithm. + ''' + X = pdiagset # to shorten notations, not a copy + m = len(X) # number of diagrams we are averaging + if m == 0: + print("Warning: computing barycenter of empty diag set. Returns None") + return None + + # store the number of off-diagonal point for each of the X_i + nb_off_diag = np.array([len(X_i) for X_i in X]) + # Initialisation of barycenter + if init is None: + i0 = np.random.randint(m) # Index of first state for the barycenter + Y = X[i0].copy() + else: + if type(init)==int: + Y = X[init].copy() + else: + Y = init.copy() + + nb_iter = 0 + + converged = False # stoping criterion + while not converged: + nb_iter += 1 + K = len(Y) # current nb of points in Y (some might be on diagonal) + G = np.full((K, m), -1, dtype=int) # will store for each j, the (index) + # point matched in each other diagram + #(might be the diagonal). + # that is G[j, i] = k <=> y_j is matched to + # x_k in the diagram i-th diagram X[i] + updated_points = np.zeros((K, 2)) # will store the new positions of + # the points of Y. + # If points disappear, there thrown + # on [0,0] by default. + new_created_points = [] # will store potential new points. + + # Step 1 : compute optimal matching (Y, X_i) for each X_i + # and create new points in Y if needed + for i in range(m): + _, indices = wasserstein_distance(Y, X[i], matching=True, order=2., internal_p=2.) + for y_j, x_i_j in indices: + if y_j >= 0: # we matched an off diagonal point to x_i_j... + if x_i_j >= 0: # ...which is also an off-diagonal point. + G[y_j, i] = x_i_j + else: # ...which is a diagonal point + G[y_j, i] = -1 # -1 stands for the diagonal (mask) + else: # We matched a diagonal point to x_i_j... + if x_i_j >= 0: # which is a off-diag point ! + # need to create new point in Y + new_y = _mean(np.array([X[i][x_i_j]]), m) + # Average this point with (m-1) copies of Delta + new_created_points.append(new_y) + + # Step 2 : Update current point position thanks to groupings computed + to_delete = [] + for j in range(K): + matched_points = [X[i][G[j, i]] for i in range(m) if G[j, i] > -1] + new_y_j = _mean(matched_points, m) + if not np.array_equal(new_y_j, np.array([0,0])): + updated_points[j] = new_y_j + else: # this points is no longer of any use. + to_delete.append(j) + # we remove the point to be deleted now. + updated_points = np.delete(updated_points, to_delete, axis=0) + + # we cannot converge if there have been new created points. + if new_created_points: + Y = np.concatenate((updated_points, new_created_points)) + else: + # Step 3 : we check convergence + if np.array_equal(updated_points, Y): + converged = True + Y = updated_points + + + if verbose: + groupings = [] + energy = 0 + log = {} + n_y = len(Y) + for i in range(m): + cost, edges = wasserstein_distance(Y, X[i], matching=True, order=2., internal_p=2.) + groupings.append(edges) + energy += cost + log["groupings"] = groupings + energy = energy/m + print(energy) + log["energy"] = energy + log["nb_iter"] = nb_iter + + return Y, log + else: + return Y + diff --git a/src/python/gudhi/wasserstein.py b/src/python/gudhi/wasserstein.py index 3dd993f9..760eea8c 100644 --- a/src/python/gudhi/wasserstein.py +++ b/src/python/gudhi/wasserstein.py @@ -9,6 +9,7 @@ import numpy as np import scipy.spatial.distance as sc +import .barycenter try: import ot except ImportError: diff --git a/src/python/test/test_wasserstein_barycenter.py b/src/python/test/test_wasserstein_barycenter.py new file mode 100755 index 00000000..f686aef5 --- /dev/null +++ b/src/python/test/test_wasserstein_barycenter.py @@ -0,0 +1,46 @@ +from gudhi.wasserstein.barycenter import lagrangian_barycenter +import numpy as np + +""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. + See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. + Author(s): Theo Lacombe + + Copyright (C) 2019 Inria + + Modification(s): + - YYYY/MM Author: Description of the modification +""" + +__author__ = "Theo Lacombe" +__copyright__ = "Copyright (C) 2019 Inria" +__license__ = "MIT" + + +def test_lagrangian_barycenter(): + + dg1 = np.array([[0.2, 0.5]]) + dg2 = np.array([[0.2, 0.7]]) + dg3 = np.array([[0.3, 0.6], [0.7, 0.8], [0.2, 0.3]]) + dg4 = np.array([]) + dg5 = np.array([]) + dg6 = np.array([]) + res = np.array([[0.27916667, 0.55416667], [0.7375, 0.7625], [0.2375, 0.2625]]) + + dg7 = np.array([[0.1, 0.15], [0.1, 0.7], [0.2, 0.22], [0.55, 0.84], [0.11, 0.91], [0.61, 0.75], [0.33, 0.46], [0.12, 0.41], [0.32, 0.48]]) + dg8 = np.array([[0., 4.], [4, 8]]) + + # error crit. + eps = 1e-7 + + + assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg1, dg2, dg3, dg4],init=3, verbose=False) - res) < eps + assert np.array_equal(lagrangian_barycenter(pdiagset=[dg4, dg5, dg6], verbose=False), np.empty(shape=(0,2))) + assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg7], verbose=False) - dg7) < eps + Y, log = lagrangian_barycenter(pdiagset=[dg4, dg8], verbose=True) + assert np.linalg.norm(Y - np.array([[1,3], [5, 7]])) < eps + assert np.abs(log["energy"] - 2) < eps + assert np.array_equal(log["groupings"][0] , np.array([[0, -1], [1, -1]])) + assert np.array_equal(log["groupings"][1] , np.array([[0, 0], [1, 1]])) + assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg8, dg4], init=np.array([[0.2, 0.6], [0.5, 0.7]]), verbose=False) - np.array([[1, 3], [5, 7]])) < eps + assert lagrangian_barycenter(pdiagset = []) is None + |