summaryrefslogtreecommitdiff
path: root/src/python
diff options
context:
space:
mode:
Diffstat (limited to 'src/python')
-rw-r--r--src/python/CMakeLists.txt3
-rw-r--r--src/python/doc/barycenter_sum.inc24
-rw-r--r--src/python/doc/barycenter_user.rst51
-rw-r--r--src/python/doc/img/barycenter.pngbin0 -> 12433 bytes
-rw-r--r--src/python/gudhi/barycenter.py227
-rwxr-xr-xsrc/python/test/test_wasserstein_barycenter.py33
6 files changed, 338 insertions, 0 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt
index 090a7446..8c36f7ee 100644
--- a/src/python/CMakeLists.txt
+++ b/src/python/CMakeLists.txt
@@ -56,6 +56,7 @@ if(PYTHONINTERP_FOUND)
# Modules that should not be auto-imported in __init__.py
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'representations', ")
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'wasserstein', ")
+ set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'barycenter', ")
add_gudhi_debug_info("Python version ${PYTHON_VERSION_STRING}")
add_gudhi_debug_info("Cython version ${CYTHON_VERSION}")
@@ -226,6 +227,7 @@ endif(CGAL_FOUND)
file(COPY "gudhi/persistence_graphical_tools.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
file(COPY "gudhi/representations" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi/")
file(COPY "gudhi/wasserstein.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
+ file(COPY "gudhi/barycenter.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
add_custom_command(
OUTPUT gudhi.so
@@ -397,6 +399,7 @@ endif(CGAL_FOUND)
# Wasserstein
if(OT_FOUND AND PYBIND11_FOUND)
add_gudhi_py_test(test_wasserstein_distance)
+ add_gudhi_py_test(test_wasserstein_barycenter)
endif()
# Representations
diff --git a/src/python/doc/barycenter_sum.inc b/src/python/doc/barycenter_sum.inc
new file mode 100644
index 00000000..da2bdd84
--- /dev/null
+++ b/src/python/doc/barycenter_sum.inc
@@ -0,0 +1,24 @@
+.. table::
+ :widths: 30 50 20
+
+ +-----------------------------------------------------------------+----------------------------------------------------------------------+------------------------------------------------------------------+
+ | .. figure:: | A Frechet mean (or barycenter) is a generalization of the arithmetic | :Author: Theo Lacombe |
+ | ./img/barycenter.png | mean in a non linear space such as the one of persistence diagrams. | |
+ | :figclass: align-center | Given a set of persistence diagrams :math:`\mu_1 \dots \mu_n`, it is | :Introduced in: GUDHI 3.1.0 |
+ | | defined as a minimizer of the variance functional, that is of | |
+ | Illustration of Frechet mean between persistence | :math:`\mu \mapsto \sum_{i=1}^n d_2(\mu,\mu_i)^2`. | :Copyright: MIT |
+ | diagrams. | where :math:`d_2` denotes the Wasserstein-2 distance between | |
+ | | persistence diagrams. | |
+ | | It is known to exist and is generically unique. However, an exact | |
+ | | computation is in general untractable. Current implementation | :Requires: Python Optimal Transport (POT) :math:`\geq` 0.5.1 |
+ | | available is based on [Turner et al, 2014], and uses an EM-scheme to | |
+ | | provide a local minimum of the variance functional (somewhat similar | |
+ | | to the Lloyd algorithm to estimate a solution to the k-means | |
+ | | problem). The local minimum returned depends on the initialization of| |
+ | | the barycenter. | |
+ | | The combinatorial structure of the algorithm limits its | |
+ | | scaling on large scale problems (thousands of diagrams and of points | |
+ | | per diagram). | |
+ +-----------------------------------------------------------------+----------------------------------------------------------------------+------------------------------------------------------------------+
+ | * :doc:`barycenter_user` | |
+ +-----------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------+
diff --git a/src/python/doc/barycenter_user.rst b/src/python/doc/barycenter_user.rst
new file mode 100644
index 00000000..714d807e
--- /dev/null
+++ b/src/python/doc/barycenter_user.rst
@@ -0,0 +1,51 @@
+:orphan:
+
+.. To get rid of WARNING: document isn't included in any toctree
+
+Barycenter user manual
+================================
+Definition
+----------
+
+.. include:: barycenter_sum.inc
+
+This implementation is based on ideas from "Frechet means for distribution of persistence diagrams", Turner et al. 2014.
+
+Function
+--------
+.. autofunction:: gudhi.barycenter.lagrangian_barycenter
+
+
+Basic example
+-------------
+
+This example computes the Frechet mean (aka Wasserstein barycenter) between four persistence diagrams.
+It is initialized on the 4th diagram, which is the empty diagram. It is encoded by np.array([]).
+Note that persistence diagrams must be submitted as (n x 2) numpy arrays and must not contain inf values.
+
+.. testcode::
+
+ import gudhi.barycenter
+ import numpy as np
+
+ dg1 = np.array([[0.2, 0.5]])
+ dg2 = np.array([[0.2, 0.7]])
+ dg3 = np.array([[0.3, 0.6], [0.7, 0.8], [0.2, 0.3]])
+ dg4 = np.array([])
+
+ bary = gudhi.barycenter.lagrangian_barycenter(pdiagset=[dg1, dg2, dg3, dg4],init=3)
+
+ message = "Wasserstein barycenter estimated:"
+ print(message)
+ print(bary)
+
+The output is:
+
+.. testoutput::
+
+ Wasserstein barycenter estimated:
+ [[0.27916667 0.55416667]
+ [0.7375 0.7625 ]
+ [0.2375 0.2625 ]]
+
+
diff --git a/src/python/doc/img/barycenter.png b/src/python/doc/img/barycenter.png
new file mode 100644
index 00000000..cad6af70
--- /dev/null
+++ b/src/python/doc/img/barycenter.png
Binary files differ
diff --git a/src/python/gudhi/barycenter.py b/src/python/gudhi/barycenter.py
new file mode 100644
index 00000000..11098afe
--- /dev/null
+++ b/src/python/gudhi/barycenter.py
@@ -0,0 +1,227 @@
+import ot
+import numpy as np
+import scipy.spatial.distance as sc
+
+
+# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+# Author(s): Theo Lacombe
+#
+# Copyright (C) 2019 Inria
+#
+# Modification(s):
+# - YYYY/MM Author: Description of the modification
+
+
+def _proj_on_diag(w):
+ '''
+ Util function to project a point on the diag.
+ '''
+ return np.array([(w[0] + w[1])/2 , (w[0] + w[1])/2])
+
+
+def _proj_on_diag_array(X):
+ '''
+ :param X: (n x 2) array encoding the points of a persistent diagram.
+ :returns: (n x 2) array encoding the (respective orthogonal) projections of the points onto the diagonal
+ '''
+ Z = (X[:,0] + X[:,1]) / 2.
+ return np.array([Z , Z]).T
+
+
+def _build_dist_matrix(X, Y, p=2., q=2.):
+ '''
+ :param X: (n x 2) numpy.array encoding the (points of the) first diagram.
+ :param Y: (m x 2) numpy.array encoding the second diagram.
+ :param q: Ground metric (i.e. norm l_q).
+ :param p: exponent for the Wasserstein metric.
+ :returns: (n+1) x (m+1) np.array encoding the cost matrix C.
+ For 1 <= i <= n, 1 <= j <= m, C[i,j] encodes the distance between X[i] and Y[j], while C[i, m+1] (resp. C[n+1, j]) encodes the distance (to the p) between X[i] (resp Y[j]) and its orthogonal proj onto the diagonal.
+ note also that C[n+1, m+1] = 0 (it costs nothing to move from the diagonal to the diagonal).
+ Note that for lagrangian_barycenter, one must use p=q=2.
+ '''
+ Xdiag = _proj_on_diag_array(X)
+ Ydiag = _proj_on_diag_array(Y)
+ if np.isinf(q):
+ C = sc.cdist(X, Y, metric='chebyshev')**p
+ Cxd = np.linalg.norm(X - Xdiag, ord=q, axis=1)**p
+ Cdy = np.linalg.norm(Y - Ydiag, ord=q, axis=1)**p
+ else:
+ C = sc.cdist(X,Y, metric='minkowski', p=q)**p
+ Cxd = np.linalg.norm(X - Xdiag, ord=q, axis=1)**p
+ Cdy = np.linalg.norm(Y - Ydiag, ord=q, axis=1)**p
+ Cf = np.hstack((C, Cxd[:,None]))
+ Cdy = np.append(Cdy, 0)
+
+ Cf = np.vstack((Cf, Cdy[None,:]))
+
+ return Cf
+
+
+def _optimal_matching(X, Y, withcost=False):
+ """
+ :param X: numpy.array of size (n x 2)
+ :param Y: numpy.array of size (m x 2)
+ :param withcost: returns also the cost corresponding to this optimal matching
+ :returns: numpy.array of shape (k x 2) encoding the list of edges in the optimal matching.
+ That is, [(i, j) ...], where (i,j) indicates that X[i] is matched to Y[j]
+ if i > len(X) or j > len(Y), it means they represent the diagonal.
+
+ """
+
+ n = len(X)
+ m = len(Y)
+ if X.size == 0: # X is empty
+ if Y.size == 0: # Y is empty
+ return np.array([[0,0]]) # the diagonal is matched to the diagonal and that's it...
+ else:
+ return np.column_stack([np.zeros(m+1, dtype=int), np.arange(m+1, dtype=int)])
+ elif Y.size == 0: # X is not empty but Y is empty
+ return np.column_stack([np.zeros(n+1, dtype=int), np.arange(n+1, dtype=int)])
+
+ # we know X, Y are not empty diags now
+ M = _build_dist_matrix(X, Y)
+
+ a = np.full(n+1, 1. / (n + m) ) # weight vector of the input diagram. Uniform here.
+ a[-1] = a[-1] * m # normalized so that we have a probability measure, required by POT
+ b = np.full(m+1, 1. / (n + m) ) # weight vector of the input diagram. Uniform here.
+ b[-1] = b[-1] * n # so that we have a probability measure, required by POT
+ P = ot.emd(a=a, b=b, M=M)*(n+m)
+ # Note : it seems POT return a permutation matrix in this situation, ie a vertex of the constraint set (generically true).
+ if withcost:
+ cost = np.sqrt(np.sum(np.multiply(P, M)))
+ P[P < 0.5] = 0 # dirty trick to avoid some numerical issues... to be improved.
+ # return the list of (i,j) such that P[i,j] > 0, i.e. x_i is matched to y_j (should it be the diag).
+ res = np.nonzero(P)
+
+ if withcost:
+ return np.column_stack(res), cost
+
+ return np.column_stack(res)
+
+
+def _mean(x, m):
+ """
+ :param x: a list of 2D-points, off diagonal, x_0... x_{k-1}
+ :param m: total amount of points taken into account, that is we have (m-k) copies of diagonal
+ :returns: the weighted mean of x with (m-k) copies of the diagonal
+ """
+ k = len(x)
+ if k > 0:
+ w = np.mean(x, axis=0)
+ w_delta = _proj_on_diag(w)
+ return (k * w + (m-k) * w_delta) / m
+ else:
+ return np.array([0, 0])
+
+
+def lagrangian_barycenter(pdiagset, init=None, verbose=False):
+ """
+ Compute the estimated barycenter computed with the algorithm provided
+ by Turner et al (2014).
+ It is a local minimum of the corresponding Frechet function.
+ :param pdiagset: a list of size N containing numpy.array of shape (n x 2)
+ (n can variate), encoding a set of
+ persistence diagrams with only finite coordinates.
+ :param init: The initial value for barycenter estimate.
+ If None, init is made on a random diagram from the dataset.
+ Otherwise, it must be an int (then we init with diagset[init])
+ or a (n x 2) numpy.array enconding a persistence diagram with n points.
+ :param verbose: if True, returns additional information about the
+ barycenter.
+ :returns: If not verbose (default), a numpy.array encoding
+ the barycenter estimate (local minima of the energy function).
+ If verbose, returns a couple (Y, log)
+ where Y is the barycenter estimate,
+ and log is a dict that contains additional informations:
+ - assigments, a list of list of pairs (i,j),
+ That is, a[k] = [(i, j) ...], where (i,j) indicates that X[i] is matched to Y[j]
+ if i > len(X) or j > len(Y), it means they represent the diagonal.
+ - energy, a float representing the Frechet mean value obtained.
+ """
+ X = pdiagset # to shorten notations, not a copy
+ m = len(X) # number of diagrams we are averaging
+ if m == 0:
+ print("Warning: computing barycenter of empty diag set. Returns None")
+ return None
+
+ nb_off_diag = np.array([len(X_i) for X_i in X]) # store the number of off-diagonal point for each of the X_i
+
+ # Initialisation of barycenter
+ if init is None:
+ i0 = np.random.randint(m) # Index of first state for the barycenter
+ Y = X[i0].copy() #copy() ensure that we do not modify X[i0]
+ else:
+ if type(init)==int:
+ Y = X[init].copy()
+ else:
+ Y = init.copy()
+
+ converged = False # stoping criterion
+ while not converged:
+ K = len(Y) # current nb of points in Y (some might be on diagonal)
+ G = np.zeros((K, m), dtype=int)-1 # will store for each j, the (index) point matched in each other diagram (might be the diagonal).
+ # that is G[j, i] = k <=> y_j is matched to
+ # x_k in the diagram i-th diagram X[i]
+ updated_points = np.zeros((K, 2)) # will store the new positions of
+ # the points of Y.
+ # If points disappear, there thrown
+ # on [0,0] by default.
+ new_created_points = [] # will store potential new points.
+
+ # Step 1 : compute optimal matching (Y, X_i) for each X_i
+ # and create new points in Y if needed
+ for i in range(m):
+ indices = _optimal_matching(Y, X[i])
+ for y_j, x_i_j in indices:
+ if y_j < K: # we matched an off diagonal point to x_i_j...
+ if x_i_j < nb_off_diag[i]: # ...which is also an off-diagonal point
+ G[y_j, i] = x_i_j
+ else: # ...which is a diagonal point
+ G[y_j, i] = -1 # -1 stands for the diagonal (mask)
+ else: # We matched a diagonal point to x_i_j...
+ if x_i_j < nb_off_diag[i]: # which is a off-diag point ! so we need to create a new point in Y
+ new_y = _mean(np.array([X[i][x_i_j]]), m) # Average this point with (m-1) copies of Delta
+ new_created_points.append(new_y)
+
+ # Step 2 : Update current point position thanks to the groupings computed
+
+ to_delete = []
+ for j in range(K):
+ matched_points = [X[i][G[j, i]] for i in range(m) if G[j, i] > -1]
+ new_y_j = _mean(matched_points, m)
+ if not np.array_equal(new_y_j, np.array([0,0])):
+ updated_points[j] = new_y_j
+ else: # this points is no longer of any use.
+ to_delete.append(j)
+ # we remove the point to be deleted now.
+ updated_points = np.delete(updated_points, to_delete, axis=0) # cannot be done in-place.
+
+
+ if new_created_points: # we cannot converge if there have been new created points.
+ Y = np.concatenate((updated_points, new_created_points))
+ else:
+ # Step 3 : we check convergence
+ if np.array_equal(updated_points, Y):
+ converged = True
+ Y = updated_points
+
+
+ if verbose:
+ groupings = []
+ energy = 0
+ log = {}
+ n_y = len(Y)
+ for i in range(m):
+ edges, cost = _optimal_matching(Y, X[i], withcost=True)
+ print(edges)
+ groupings.append([x_i_j for (y_j, x_i_j) in enumerate(edges) if y_j < n_y])
+ energy += cost
+ log["groupings"] = groupings
+ energy = energy/m
+ log["energy"] = energy
+
+ return Y, log
+ else:
+ return Y
+
diff --git a/src/python/test/test_wasserstein_barycenter.py b/src/python/test/test_wasserstein_barycenter.py
new file mode 100755
index 00000000..910d23ff
--- /dev/null
+++ b/src/python/test/test_wasserstein_barycenter.py
@@ -0,0 +1,33 @@
+from gudhi.barycenter import lagrangian_barycenter
+import numpy as np
+
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Theo Lacombe
+
+ Copyright (C) 2019 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+__author__ = "Theo Lacombe"
+__copyright__ = "Copyright (C) 2019 Inria"
+__license__ = "MIT"
+
+
+def test_lagrangian_barycenter():
+
+ dg1 = np.array([[0.2, 0.5]])
+ dg2 = np.array([[0.2, 0.7]])
+ dg3 = np.array([[0.3, 0.6], [0.7, 0.8], [0.2, 0.3]])
+ dg4 = np.array([])
+ dg5 = np.array([])
+ dg6 = np.array([])
+ res = np.array([[0.27916667, 0.55416667], [0.7375, 0.7625], [0.2375, 0.2625]])
+
+ dg7 = np.array([[0.1, 0.15], [0.1, 0.7], [0.2, 0.22], [0.55, 0.84], [0.11, 0.91], [0.61, 0.75], [0.33, 0.46], [0.12, 0.41], [0.32, 0.48]])
+
+ assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg1, dg2, dg3, dg4],init=3, verbose=False) - res) < 0.001
+ assert np.array_equal(lagrangian_barycenter(pdiagset=[dg4, dg5, dg6], verbose=False), np.empty(shape=(0,2)))
+ assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg7], verbose=False) - dg7) < 0.001