summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2020-03-16 13:16:04 +0100
committertlacombe <lacombe1993@gmail.com>2020-03-16 13:16:04 +0100
commit4b546a43fe14178dcfb2b327e27a580fc9811499 (patch)
tree42376206ef6750113f76a828d1dfeb7b3bfc48eb
parent16e04f4f14cc73956064aef283e29134f00d306c (diff)
update doc (indentation, mention of -1 for the diag) and added a few more tests
-rw-r--r--src/python/gudhi/barycenter.py30
-rwxr-xr-xsrc/python/test/test_wasserstein_barycenter.py15
2 files changed, 23 insertions, 22 deletions
diff --git a/src/python/gudhi/barycenter.py b/src/python/gudhi/barycenter.py
index a41b5906..3af12c14 100644
--- a/src/python/gudhi/barycenter.py
+++ b/src/python/gudhi/barycenter.py
@@ -96,9 +96,8 @@ def _optimal_matching(X, Y, withcost=False):
def lagrangian_barycenter(pdiagset, init=None, verbose=False):
'''
:param pdiagset: a list of size m containing numpy.array of shape (n x 2)
- (n can variate), encoding a set of
- persistence diagrams with only finite coordinates.
- If empty, returns None.
+ (n can variate), encoding a set of
+ persistence diagrams with only finite coordinates.
:param init: The initial value for barycenter estimate.
If None, init is made on a random diagram from the dataset.
Otherwise, it must be an int
@@ -106,24 +105,25 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False):
or a (n x 2) numpy.array enconding
a persistence diagram with n points.
:param verbose: if True, returns additional information about the
- barycenter.
+ barycenter.
:returns: If not verbose (default), a numpy.array encoding
- the barycenter estimate
+ the barycenter estimate of pdiagset
(local minima of the energy function).
+ If pdiagset is empty, returns None.
If verbose, returns a couple (Y, log)
where Y is the barycenter estimate,
and log is a dict that contains additional informations:
- groupings, a list of list of pairs (i,j),
- That is, G[k] = [(i, j) ...], where (i,j) indicates
- that X[i] is matched to Y[j]
- if i > len(X) or j > len(Y), it means they
- represent the diagonal.
+ That is, G[k] = [(i, j) ...], where (i,j) indicates
+ that X[i] is matched to Y[j]
+ if i = -1 or j = -1, it means they
+ represent the diagonal.
- energy, a float representing the Frechet
- energy value obtained,
- that is the mean of squared distances
- of observations to the output.
+ energy value obtained,
+ that is the mean of squared distances
+ of observations to the output.
- nb_iter, integer representing the number of iterations
- performed before convergence of the algorithm.
+ performed before convergence of the algorithm.
'''
X = pdiagset # to shorten notations, not a copy
m = len(X) # number of diagrams we are averaging
@@ -136,7 +136,7 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False):
# Initialisation of barycenter
if init is None:
i0 = np.random.randint(m) # Index of first state for the barycenter
- Y = X[i0].copy() #copy() ensure that we do not modify X[i0]
+ Y = X[i0].copy()
else:
if type(init)==int:
Y = X[init].copy()
@@ -149,7 +149,7 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False):
while not converged:
nb_iter += 1
K = len(Y) # current nb of points in Y (some might be on diagonal)
- G = np.zeros((K, m), dtype=int)-1 # will store for each j, the (index)
+ G = np.full((K, m), -1, dtype=int) # will store for each j, the (index)
# point matched in each other diagram
#(might be the diagonal).
# that is G[j, i] = k <=> y_j is matched to
diff --git a/src/python/test/test_wasserstein_barycenter.py b/src/python/test/test_wasserstein_barycenter.py
index a58a4d62..5167cb84 100755
--- a/src/python/test/test_wasserstein_barycenter.py
+++ b/src/python/test/test_wasserstein_barycenter.py
@@ -27,19 +27,20 @@ def test_lagrangian_barycenter():
res = np.array([[0.27916667, 0.55416667], [0.7375, 0.7625], [0.2375, 0.2625]])
dg7 = np.array([[0.1, 0.15], [0.1, 0.7], [0.2, 0.22], [0.55, 0.84], [0.11, 0.91], [0.61, 0.75], [0.33, 0.46], [0.12, 0.41], [0.32, 0.48]])
- dg8 = np.array([[0., 4.]])
+ dg8 = np.array([[0., 4.], [4, 8]])
# error crit.
- eps = 0.000001
+ eps = 1e-7
assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg1, dg2, dg3, dg4],init=3, verbose=False) - res) < eps
assert np.array_equal(lagrangian_barycenter(pdiagset=[dg4, dg5, dg6], verbose=False), np.empty(shape=(0,2)))
assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg7], verbose=False) - dg7) < eps
Y, log = lagrangian_barycenter(pdiagset=[dg4, dg8], verbose=True)
- assert np.linalg.norm(Y - np.array([[1,3]])) < eps
- assert np.abs(log["energy"] - 2) < eps
- assert np.array_equal(log["groupings"][0] , np.array([[0, -1]]))
- assert np.array_equal(log["groupings"][1] , np.array([[0, 0]]))
- assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg8, dg4], init=np.array([[0.2, 0.6], [0.5, 0.7]]), verbose=False) - np.array([[1, 3]])) < eps
+ assert np.linalg.norm(Y - np.array([[1,3], [5, 7]])) < eps
+ assert np.abs(log["energy"] - 4) < eps
+ assert np.array_equal(log["groupings"][0] , np.array([[0, -1], [1, -1]]))
+ assert np.array_equal(log["groupings"][1] , np.array([[0, 0], [1, 1]]))
+ assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg8, dg4], init=np.array([[0.2, 0.6], [0.5, 0.7]]), verbose=False) - np.array([[1, 3], [5, 7]])) < eps
assert lagrangian_barycenter(pdiagset = []) is None
+