summaryrefslogtreecommitdiff
path: root/src/python/gudhi
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2020-02-17 17:50:37 +0100
committertlacombe <lacombe1993@gmail.com>2020-02-17 17:50:37 +0100
commit5e4bc93510f50dacdb59f1a7578aca72817c9631 (patch)
tree93b6ff73c6f9eb7851c620e0575fc3f391b803a0 /src/python/gudhi
parentf8fe3fdb01f6161b57da732a1c3f0c14a8b359a6 (diff)
update doc + removed normalization + use argwhere
Diffstat (limited to 'src/python/gudhi')
-rw-r--r--src/python/gudhi/barycenter.py29
1 files changed, 12 insertions, 17 deletions
diff --git a/src/python/gudhi/barycenter.py b/src/python/gudhi/barycenter.py
index 4a877b4a..c54066ec 100644
--- a/src/python/gudhi/barycenter.py
+++ b/src/python/gudhi/barycenter.py
@@ -15,12 +15,6 @@ import scipy.spatial.distance as sc
from gudhi.wasserstein import _build_dist_matrix, _perstot
-def _proj_on_diag(w):
- '''
- Util function to project a point on the diag.
- '''
- return np.array([(w[0] + w[1])/2 , (w[0] + w[1])/2])
-
def _mean(x, m):
"""
@@ -32,7 +26,7 @@ def _mean(x, m):
k = len(x)
if k > 0:
w = np.mean(x, axis=0)
- w_delta = _proj_on_diag(w)
+ w_delta = (w[0] + w[1]) / 2 * np.ones(2)
return (k * w + (m-k) * w_delta) / m
else:
return np.array([0, 0])
@@ -80,31 +74,32 @@ def _optimal_matching(X, Y, withcost=False):
# we know X, Y are not empty diags now
M = _build_dist_matrix(X, Y, order=2, internal_p=2)
- a = np.full(n+1, 1. / (n + m) )
- a[-1] = a[-1] * m
- b = np.full(m+1, 1. / (n + m) )
- b[-1] = b[-1] * n
- P = ot.emd(a=a, b=b, M=M)*(n+m)
+ a = np.ones(n+1)
+ a[-1] = m
+ b = np.ones(m+1)
+ b[-1] = n
+ P = ot.emd(a=a, b=b, M=M)
# Note : it seems POT returns a permutation matrix in this situation,
# ie a vertex of the constraint set (generically true).
if withcost:
cost = np.sum(np.multiply(P, M))
P[P < 0.5] = 0 # dirty trick to avoid some numerical issues... to improve.
- res = np.nonzero(P)
+ res = np.argwhere(P)
# return the list of (i,j) such that P[i,j] > 0,
#i.e. x_i is matched to y_j (should it be the diag).
if withcost:
- return np.column_stack(res), cost
-
- return np.column_stack(res)
+ return res, cost
+ return res
def lagrangian_barycenter(pdiagset, init=None, verbose=False):
"""
- Compute the estimated barycenter computed with the algorithm provided
+ Returns the estimated barycenter computed with the algorithm provided
by Turner et al (2014).
+ As the algorithm is not convex, the output depends on initialization.
It is a local minimum of the corresponding Frechet function.
+
:param pdiagset: a list of size m containing numpy.array of shape (n x 2)
(n can variate), encoding a set of
persistence diagrams with only finite coordinates.