summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2020-02-17 17:50:37 +0100
committertlacombe <lacombe1993@gmail.com>2020-02-17 17:50:37 +0100
commit5e4bc93510f50dacdb59f1a7578aca72817c9631 (patch)
tree93b6ff73c6f9eb7851c620e0575fc3f391b803a0 /src
parentf8fe3fdb01f6161b57da732a1c3f0c14a8b359a6 (diff)
update doc + removed normalization + use argwhere
Diffstat (limited to 'src')
-rw-r--r--src/python/doc/barycenter_user.rst7
-rw-r--r--src/python/gudhi/barycenter.py29
2 files changed, 18 insertions, 18 deletions
diff --git a/src/python/doc/barycenter_user.rst b/src/python/doc/barycenter_user.rst
index 714d807e..f81e9358 100644
--- a/src/python/doc/barycenter_user.rst
+++ b/src/python/doc/barycenter_user.rst
@@ -9,7 +9,8 @@ Definition
.. include:: barycenter_sum.inc
-This implementation is based on ideas from "Frechet means for distribution of persistence diagrams", Turner et al. 2014.
+This implementation is based on ideas from "Frechet means for distribution of
+persistence diagrams", Turner et al. 2014.
Function
--------
@@ -21,6 +22,10 @@ Basic example
This example computes the Frechet mean (aka Wasserstein barycenter) between four persistence diagrams.
It is initialized on the 4th diagram, which is the empty diagram. It is encoded by np.array([]).
+As the underlying optimization problem is not convex, the output of the algorithm depends on the initialization and is only a local minimum of the objective function.
+Initialization can be either given as an integer (in which case the i-th diagram of the list is used as initial estimate)
+or as a diagram.
+If None, one of the diagrams in the list is randomly selected as the initial estimate.
Note that persistence diagrams must be submitted as (n x 2) numpy arrays and must not contain inf values.
.. testcode::
diff --git a/src/python/gudhi/barycenter.py b/src/python/gudhi/barycenter.py
index 4a877b4a..c54066ec 100644
--- a/src/python/gudhi/barycenter.py
+++ b/src/python/gudhi/barycenter.py
@@ -15,12 +15,6 @@ import scipy.spatial.distance as sc
from gudhi.wasserstein import _build_dist_matrix, _perstot
-def _proj_on_diag(w):
- '''
- Util function to project a point on the diag.
- '''
- return np.array([(w[0] + w[1])/2 , (w[0] + w[1])/2])
-
def _mean(x, m):
"""
@@ -32,7 +26,7 @@ def _mean(x, m):
k = len(x)
if k > 0:
w = np.mean(x, axis=0)
- w_delta = _proj_on_diag(w)
+ w_delta = (w[0] + w[1]) / 2 * np.ones(2)
return (k * w + (m-k) * w_delta) / m
else:
return np.array([0, 0])
@@ -80,31 +74,32 @@ def _optimal_matching(X, Y, withcost=False):
# we know X, Y are not empty diags now
M = _build_dist_matrix(X, Y, order=2, internal_p=2)
- a = np.full(n+1, 1. / (n + m) )
- a[-1] = a[-1] * m
- b = np.full(m+1, 1. / (n + m) )
- b[-1] = b[-1] * n
- P = ot.emd(a=a, b=b, M=M)*(n+m)
+ a = np.ones(n+1)
+ a[-1] = m
+ b = np.ones(m+1)
+ b[-1] = n
+ P = ot.emd(a=a, b=b, M=M)
# Note : it seems POT returns a permutation matrix in this situation,
# ie a vertex of the constraint set (generically true).
if withcost:
cost = np.sum(np.multiply(P, M))
P[P < 0.5] = 0 # dirty trick to avoid some numerical issues... to improve.
- res = np.nonzero(P)
+ res = np.argwhere(P)
# return the list of (i,j) such that P[i,j] > 0,
#i.e. x_i is matched to y_j (should it be the diag).
if withcost:
- return np.column_stack(res), cost
-
- return np.column_stack(res)
+ return res, cost
+ return res
def lagrangian_barycenter(pdiagset, init=None, verbose=False):
"""
- Compute the estimated barycenter computed with the algorithm provided
+ Returns the estimated barycenter computed with the algorithm provided
by Turner et al (2014).
+ As the optimization problem is not convex, the output depends on the initialization.
It is a local minimum of the corresponding Frechet function.
+
:param pdiagset: a list of size m containing numpy.array of shape (n x 2)
(n can vary), encoding a set of
persistence diagrams with only finite coordinates.