summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/python/CMakeLists.txt3
-rw-r--r--src/python/doc/barycenter_sum.inc24
-rw-r--r--src/python/doc/barycenter_user.rst60
-rw-r--r--src/python/doc/img/barycenter.pngbin0 -> 12433 bytes
-rw-r--r--src/python/doc/index.rst5
-rw-r--r--src/python/gudhi/barycenter.py227
-rwxr-xr-xsrc/python/test/test_wasserstein_barycenter.py46
7 files changed, 365 insertions, 0 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt
index 22af3ec9..9fa7b129 100644
--- a/src/python/CMakeLists.txt
+++ b/src/python/CMakeLists.txt
@@ -56,6 +56,7 @@ if(PYTHONINTERP_FOUND)
# Modules that should not be auto-imported in __init__.py
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'representations', ")
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'wasserstein', ")
+ set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'barycenter', ")
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'point_cloud', ")
add_gudhi_debug_info("Python version ${PYTHON_VERSION_STRING}")
@@ -227,6 +228,7 @@ if(PYTHONINTERP_FOUND)
file(COPY "gudhi/persistence_graphical_tools.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
file(COPY "gudhi/representations" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi/")
file(COPY "gudhi/wasserstein.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
+ file(COPY "gudhi/barycenter.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
file(COPY "gudhi/point_cloud" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
add_custom_command(
@@ -399,6 +401,7 @@ if(PYTHONINTERP_FOUND)
# Wasserstein
if(OT_FOUND AND PYBIND11_FOUND)
add_gudhi_py_test(test_wasserstein_distance)
+ add_gudhi_py_test(test_wasserstein_barycenter)
endif()
# Representations
diff --git a/src/python/doc/barycenter_sum.inc b/src/python/doc/barycenter_sum.inc
new file mode 100644
index 00000000..da2bdd84
--- /dev/null
+++ b/src/python/doc/barycenter_sum.inc
@@ -0,0 +1,24 @@
+.. table::
+ :widths: 30 50 20
+
+ +-----------------------------------------------------------------+----------------------------------------------------------------------+------------------------------------------------------------------+
+ | .. figure:: | A Frechet mean (or barycenter) is a generalization of the arithmetic | :Author: Theo Lacombe |
+ | ./img/barycenter.png | mean in a non linear space such as the one of persistence diagrams. | |
+ | :figclass: align-center | Given a set of persistence diagrams :math:`\mu_1 \dots \mu_n`, it is | :Introduced in: GUDHI 3.1.0 |
+ | | defined as a minimizer of the variance functional, that is of | |
+ | Illustration of Frechet mean between persistence | :math:`\mu \mapsto \sum_{i=1}^n d_2(\mu,\mu_i)^2`. | :Copyright: MIT |
+ | diagrams. | where :math:`d_2` denotes the Wasserstein-2 distance between | |
+ | | persistence diagrams. | |
+ | | It is known to exist and is generically unique. However, an exact | |
+ | | computation is in general untractable. Current implementation | :Requires: Python Optimal Transport (POT) :math:`\geq` 0.5.1 |
+ | | available is based on [Turner et al, 2014], and uses an EM-scheme to | |
+ | | provide a local minimum of the variance functional (somewhat similar | |
+ | | to the Lloyd algorithm to estimate a solution to the k-means | |
+ | | problem). The local minimum returned depends on the initialization of| |
+ | | the barycenter. | |
+ | | The combinatorial structure of the algorithm limits its | |
+ | | scaling on large scale problems (thousands of diagrams and of points | |
+ | | per diagram). | |
+ +-----------------------------------------------------------------+----------------------------------------------------------------------+------------------------------------------------------------------+
+ | * :doc:`barycenter_user` | |
+ +-----------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------+
diff --git a/src/python/doc/barycenter_user.rst b/src/python/doc/barycenter_user.rst
new file mode 100644
index 00000000..83e9bebb
--- /dev/null
+++ b/src/python/doc/barycenter_user.rst
@@ -0,0 +1,60 @@
+:orphan:
+
+.. To get rid of WARNING: document isn't included in any toctree
+
+Barycenter user manual
+================================
+Definition
+----------
+
+.. include:: barycenter_sum.inc
+
+This implementation is based on ideas from "Frechet means for distribution of
+persistence diagrams", Turner et al. 2014.
+
+Function
+--------
+.. autofunction:: gudhi.barycenter.lagrangian_barycenter
+
+
+Basic example
+-------------
+
+This example computes the Frechet mean (aka Wasserstein barycenter) between
+four persistence diagrams.
+It is initialized on the 4th diagram.
+As the algorithm is not convex, its output depends on the initialization and
+is only a local minimum of the objective function.
+Initialization can be either given as an integer (in which case the i-th
+diagram of the list is used as initial estimate) or as a diagram.
+If None, it will randomly select one of the diagram of the list
+as initial estimate.
+Note that persistence diagrams must be submitted as
+(n x 2) numpy arrays and must not contain inf values.
+
+.. testcode::
+
+ import gudhi.barycenter
+ import numpy as np
+
+ dg1 = np.array([[0.2, 0.5]])
+ dg2 = np.array([[0.2, 0.7]])
+ dg3 = np.array([[0.3, 0.6], [0.7, 0.8], [0.2, 0.3]])
+ dg4 = np.array([])
+ pdiagset = [dg1, dg2, dg3, dg4]
+ bary = gudhi.barycenter.lagrangian_barycenter(pdiagset=pdiagset,init=3)
+
+ message = "Wasserstein barycenter estimated:"
+ print(message)
+ print(bary)
+
+The output is:
+
+.. testoutput::
+
+ Wasserstein barycenter estimated:
+ [[0.27916667 0.55416667]
+ [0.7375 0.7625 ]
+ [0.2375 0.2625 ]]
+
+
diff --git a/src/python/doc/img/barycenter.png b/src/python/doc/img/barycenter.png
new file mode 100644
index 00000000..cad6af70
--- /dev/null
+++ b/src/python/doc/img/barycenter.png
Binary files differ
diff --git a/src/python/doc/index.rst b/src/python/doc/index.rst
index 3387a64f..96cd3513 100644
--- a/src/python/doc/index.rst
+++ b/src/python/doc/index.rst
@@ -71,6 +71,11 @@ Wasserstein distance
.. include:: wasserstein_distance_sum.inc
+Barycenter
+============
+
+.. include:: barycenter_sum.inc
+
Persistence representations
===========================
diff --git a/src/python/gudhi/barycenter.py b/src/python/gudhi/barycenter.py
new file mode 100644
index 00000000..517cdb2f
--- /dev/null
+++ b/src/python/gudhi/barycenter.py
@@ -0,0 +1,227 @@
+# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+# Author(s): Theo Lacombe
+#
+# Copyright (C) 2019 Inria
+#
+# Modification(s):
+# - YYYY/MM Author: Description of the modification
+
+
+import ot
+import numpy as np
+import scipy.spatial.distance as sc
+
+from gudhi.wasserstein import _build_dist_matrix, _perstot
+
+
+
+def _mean(x, m):
+ '''
+ :param x: a list of 2D-points, off diagonal, x_0... x_{k-1}
+ :param m: total amount of points taken into account,
+ that is we have (m-k) copies of diagonal
+ :returns: the weighted mean of x with (m-k) copies of the diagonal
+ '''
+ k = len(x)
+ if k > 0:
+ w = np.mean(x, axis=0)
+ w_delta = (w[0] + w[1]) / 2 * np.ones(2)
+ return (k * w + (m-k) * w_delta) / m
+ else:
+ return np.array([0, 0])
+
+
+def _optimal_matching(X, Y, withcost=False):
+ '''
+ :param X: numpy.array of size (n x 2)
+ :param Y: numpy.array of size (m x 2)
+ :param withcost: returns also the cost corresponding to the optimal matching
+ :returns: numpy.array of shape (k x 2) encoding the list of edges
+ in the optimal matching.
+ That is, [[i, j] ...], where (i,j) indicates
+ that X[i] is matched to Y[j]
+ if i >= len(X) or j >= len(Y), it means they
+ represent the diagonal.
+ They will be encoded by -1 afterwards.
+
+ NOTE : this code will be removed for final merge,
+ and wasserstein.optimal_matching will be used instead.
+ '''
+
+ n = len(X)
+ m = len(Y)
+ # Start by handling empty diagrams. Could it be shorten?
+ if X.size == 0: # X is empty
+ if Y.size == 0: # Y is empty
+ res = np.array([[0,0]]) # the diagonal is matched to the diagonal
+ if withcost:
+ return res, 0
+ else:
+ return res
+ else: # X is empty but not Y
+ res = np.array([[0, i] for i in range(m)])
+ cost = _perstot(Y, order=2, internal_p=2)**2
+ if withcost:
+ return res, cost
+ else:
+ return res
+ elif Y.size == 0: # X is not empty but Y is empty
+ res = np.array([[i,0] for i in range(n)])
+ cost = _perstot(X, order=2, internal_p=2)**2
+ if withcost:
+ return res, cost
+ else:
+ return res
+
+ # we know X, Y are not empty diags now
+ M = _build_dist_matrix(X, Y, order=2, internal_p=2)
+
+ a = np.ones(n+1)
+ a[-1] = m
+ b = np.ones(m+1)
+ b[-1] = n
+ P = ot.emd(a=a, b=b, M=M)
+ # Note : it seems POT returns a permutation matrix in this situation,
+ # ie a vertex of the constraint set (generically true).
+ if withcost:
+ cost = np.sum(np.multiply(P, M))
+ P[P < 0.5] = 0 # dirty trick to avoid some numerical issues... to improve.
+ res = np.argwhere(P)
+
+ # return the list of (i,j) such that P[i,j] > 0,
+ #i.e. x_i is matched to y_j (should it be the diag).
+ if withcost:
+ return res, cost
+ return res
+
+
+def lagrangian_barycenter(pdiagset, init=None, verbose=False):
+ '''
+ :param pdiagset: a list of size m containing numpy.array of shape (n x 2)
+ (n can variate), encoding a set of
+ persistence diagrams with only finite coordinates.
+ :param init: The initial value for barycenter estimate.
+ If None, init is made on a random diagram from the dataset.
+ Otherwise, it must be an int
+ (then we init with diagset[init])
+ or a (n x 2) numpy.array enconding
+ a persistence diagram with n points.
+ :param verbose: if True, returns additional information about the
+ barycenter.
+ :returns: If not verbose (default), a numpy.array encoding
+ the barycenter estimate of pdiagset
+ (local minima of the energy function).
+ If pdiagset is empty, returns None.
+ If verbose, returns a couple (Y, log)
+ where Y is the barycenter estimate,
+ and log is a dict that contains additional informations:
+ - groupings, a list of list of pairs (i,j),
+ That is, G[k] = [(i, j) ...], where (i,j) indicates
+ that X[i] is matched to Y[j]
+ if i = -1 or j = -1, it means they
+ represent the diagonal.
+ - energy, a float representing the Frechet
+ energy value obtained,
+ that is the mean of squared distances
+ of observations to the output.
+ - nb_iter, integer representing the number of iterations
+ performed before convergence of the algorithm.
+ '''
+ X = pdiagset # to shorten notations, not a copy
+ m = len(X) # number of diagrams we are averaging
+ if m == 0:
+ print("Warning: computing barycenter of empty diag set. Returns None")
+ return None
+
+ # store the number of off-diagonal point for each of the X_i
+ nb_off_diag = np.array([len(X_i) for X_i in X])
+ # Initialisation of barycenter
+ if init is None:
+ i0 = np.random.randint(m) # Index of first state for the barycenter
+ Y = X[i0].copy()
+ else:
+ if type(init)==int:
+ Y = X[init].copy()
+ else:
+ Y = init.copy()
+
+ nb_iter = 0
+
+ converged = False # stoping criterion
+ while not converged:
+ nb_iter += 1
+ K = len(Y) # current nb of points in Y (some might be on diagonal)
+ G = np.full((K, m), -1, dtype=int) # will store for each j, the (index)
+ # point matched in each other diagram
+ #(might be the diagonal).
+ # that is G[j, i] = k <=> y_j is matched to
+ # x_k in the diagram i-th diagram X[i]
+ updated_points = np.zeros((K, 2)) # will store the new positions of
+ # the points of Y.
+ # If points disappear, there thrown
+ # on [0,0] by default.
+ new_created_points = [] # will store potential new points.
+
+ # Step 1 : compute optimal matching (Y, X_i) for each X_i
+ # and create new points in Y if needed
+ for i in range(m):
+ indices = _optimal_matching(Y, X[i])
+ for y_j, x_i_j in indices:
+ if y_j < K: # we matched an off diagonal point to x_i_j...
+ # ...which is also an off-diagonal point.
+ if x_i_j < nb_off_diag[i]:
+ G[y_j, i] = x_i_j
+ else: # ...which is a diagonal point
+ G[y_j, i] = -1 # -1 stands for the diagonal (mask)
+ else: # We matched a diagonal point to x_i_j...
+ if x_i_j < nb_off_diag[i]: # which is a off-diag point !
+ # need to create new point in Y
+ new_y = _mean(np.array([X[i][x_i_j]]), m)
+ # Average this point with (m-1) copies of Delta
+ new_created_points.append(new_y)
+
+ # Step 2 : Update current point position thanks to groupings computed
+ to_delete = []
+ for j in range(K):
+ matched_points = [X[i][G[j, i]] for i in range(m) if G[j, i] > -1]
+ new_y_j = _mean(matched_points, m)
+ if not np.array_equal(new_y_j, np.array([0,0])):
+ updated_points[j] = new_y_j
+ else: # this points is no longer of any use.
+ to_delete.append(j)
+ # we remove the point to be deleted now.
+ updated_points = np.delete(updated_points, to_delete, axis=0)
+
+ # we cannot converge if there have been new created points.
+ if new_created_points:
+ Y = np.concatenate((updated_points, new_created_points))
+ else:
+ # Step 3 : we check convergence
+ if np.array_equal(updated_points, Y):
+ converged = True
+ Y = updated_points
+
+
+ if verbose:
+ groupings = []
+ energy = 0
+ log = {}
+ n_y = len(Y)
+ for i in range(m):
+ edges, cost = _optimal_matching(Y, X[i], withcost=True)
+ n_x = len(X[i])
+ G = edges[np.where(edges[:,0]<n_y)]
+ idx = np.where(G[:,1] >= n_x)
+ G[idx,1] = -1 # -1 will encode the diagonal
+ groupings.append(G)
+ energy += cost
+ log["groupings"] = groupings
+ energy = energy/m
+ log["energy"] = energy
+ log["nb_iter"] = nb_iter
+
+ return Y, log
+ else:
+ return Y
+
diff --git a/src/python/test/test_wasserstein_barycenter.py b/src/python/test/test_wasserstein_barycenter.py
new file mode 100755
index 00000000..5167cb84
--- /dev/null
+++ b/src/python/test/test_wasserstein_barycenter.py
@@ -0,0 +1,46 @@
+from gudhi.barycenter import lagrangian_barycenter
+import numpy as np
+
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Theo Lacombe
+
+ Copyright (C) 2019 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+__author__ = "Theo Lacombe"
+__copyright__ = "Copyright (C) 2019 Inria"
+__license__ = "MIT"
+
+
+def test_lagrangian_barycenter():
+
+ dg1 = np.array([[0.2, 0.5]])
+ dg2 = np.array([[0.2, 0.7]])
+ dg3 = np.array([[0.3, 0.6], [0.7, 0.8], [0.2, 0.3]])
+ dg4 = np.array([])
+ dg5 = np.array([])
+ dg6 = np.array([])
+ res = np.array([[0.27916667, 0.55416667], [0.7375, 0.7625], [0.2375, 0.2625]])
+
+ dg7 = np.array([[0.1, 0.15], [0.1, 0.7], [0.2, 0.22], [0.55, 0.84], [0.11, 0.91], [0.61, 0.75], [0.33, 0.46], [0.12, 0.41], [0.32, 0.48]])
+ dg8 = np.array([[0., 4.], [4, 8]])
+
+ # error crit.
+ eps = 1e-7
+
+
+ assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg1, dg2, dg3, dg4],init=3, verbose=False) - res) < eps
+ assert np.array_equal(lagrangian_barycenter(pdiagset=[dg4, dg5, dg6], verbose=False), np.empty(shape=(0,2)))
+ assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg7], verbose=False) - dg7) < eps
+ Y, log = lagrangian_barycenter(pdiagset=[dg4, dg8], verbose=True)
+ assert np.linalg.norm(Y - np.array([[1,3], [5, 7]])) < eps
+ assert np.abs(log["energy"] - 4) < eps
+ assert np.array_equal(log["groupings"][0] , np.array([[0, -1], [1, -1]]))
+ assert np.array_equal(log["groupings"][1] , np.array([[0, 0], [1, 1]]))
+ assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg8, dg4], init=np.array([[0.2, 0.6], [0.5, 0.7]]), verbose=False) - np.array([[1, 3], [5, 7]])) < eps
+ assert lagrangian_barycenter(pdiagset = []) is None
+