summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2019-12-19 17:25:01 +0100
committertlacombe <lacombe1993@gmail.com>2019-12-19 17:31:22 +0100
commit180add9067bc9bd0609362717972eeeb8d2f6713 (patch)
tree96558c44732f9296327aada5b760fda01c5a2029 /src
parentd91585af64805a11a4d446d9e3f6467f3394d0c6 (diff)
clean code and doc
Diffstat (limited to 'src')
-rw-r--r--src/python/gudhi/barycenter.py129
1 files changed, 36 insertions, 93 deletions
diff --git a/src/python/gudhi/barycenter.py b/src/python/gudhi/barycenter.py
index 43602a6e..c2173dba 100644
--- a/src/python/gudhi/barycenter.py
+++ b/src/python/gudhi/barycenter.py
@@ -58,12 +58,13 @@ def _build_dist_matrix(X, Y, p=2., q=2.):
return Cf
-def _optimal_matching(X, Y):
+def _optimal_matching(X, Y, withcost=False):
"""
:param X: numpy.array of size (n x 2)
:param Y: numpy.array of size (m x 2)
+ :param withcost: returns also the cost corresponding to this optimal matching
:returns: numpy.array of shape (k x 2) encoding the list of edges in the optimal matching.
- That is, [[(i, j) ...], where (i,j) indicates that X[i] is matched to Y[j]
+ That is, [(i, j) ...], where (i,j) indicates that X[i] is matched to Y[j]
if i > len(X) or j > len(Y), it means they represent the diagonal.
"""
@@ -74,10 +75,10 @@ def _optimal_matching(X, Y):
if Y.size == 0: # Y is empty
return np.array([[0,0]]) # the diagonal is matched to the diagonal and that's it...
else:
- return np.column_stack([np.zeros(m+1, dtype=int), np.arange(m+1, dtype=int)]) # TO BE CORRECTED
+ return np.column_stack([np.zeros(m+1, dtype=int), np.arange(m+1, dtype=int)])
elif Y.size == 0: # X is not empty but Y is empty
- return np.column_stack([np.zeros(n+1, dtype=int), np.arange(n+1, dtype=int)]) # TO BE CORRECTED
-
+ return np.column_stack([np.zeros(n+1, dtype=int), np.arange(n+1, dtype=int)])
+
# we know X, Y are not empty diags now
M = _build_dist_matrix(X, Y)
@@ -86,12 +87,16 @@ def _optimal_matching(X, Y):
b = np.full(m+1, 1. / (n + m) ) # weight vector of the input diagram. Uniform here.
b[-1] = b[-1] * n # so that we have a probability measure, required by POT
P = ot.emd(a=a, b=b, M=M)*(n+m)
- # Note : it seems POT return a permutation matrix in this situation,
- # ...guarantee...?
- # It should be enough to check that the algorithm only iterates on vertices of the transportation polytope.
+ # Note : it seems POT return a permutation matrix in this situation, ie a vertex of the constraint set (generically true).
+ if withcost:
+ cost = np.sqrt(np.sum(np.multiply(P, M)))
P[P < 0.5] = 0 # dirty trick to avoid some numerical issues... to be improved.
# return the list of (i,j) such that P[i,j] > 0, i.e. x_i is matched to y_j (should it be the diag).
res = np.nonzero(P)
+
+ if withcost:
+ return np.column_stack(res), cost
+
return np.column_stack(res)
@@ -123,13 +128,16 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False):
Otherwise, it must be an int (then we init with diagset[init])
or a (n x 2) numpy.array enconding a persistence diagram with n points.
:param verbose: if True, returns additional information about the
- barycenters (assignment and energy).
+ barycenter.
:returns: If not verbose (default), a numpy.array encoding
the barycenter estimate (local minima of the energy function).
- If verbose, returns a triplet (Y, a, e)
- where Y is the barycenter estimate, a is the assignments between the
- points of Y and thoses of the diagrams,
- and e is the energy value reached by the estimate.
+ If verbose, returns a couple (Y, log)
+ where Y is the barycenter estimate,
+ and log is a dict that contains additional informations:
+ - assigments, a list of list of pairs (i,j),
+ That is, a[k] = [(i, j) ...], where (i,j) indicates that X[i] is matched to Y[j]
+ if i > len(X) or j > len(Y), it means they represent the diagonal.
+ - energy, a float representing the Frechet mean value obtained.
"""
X = pdiagset # to shorten notations, not a copy
m = len(X) # number of diagrams we are averaging
@@ -200,25 +208,29 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False):
if verbose:
- matchings = []
- #energy = 0
+ groupings = []
+ energy = 0
+ log = {}
n_y = len(Y)
for i in range(m):
- edges = _optimal_matching(Y, X[i])
- matchings.append([x_i_j for (y_j, x_i_j) in enumerate(edges) if y_j < n_y])
- # energy += sum([M[i,j] for i,j in enumerate(edges)])
-
- # energy = energy/m
- return Y, matchings #, energy
+ edges, cost = _optimal_matching(Y, X[i], withcost=True)
+ print(edges)
+ groupings.append([x_i_j for (y_j, x_i_j) in enumerate(edges) if y_j < n_y])
+ energy += cost
+ log["groupings"] = groupings
+ energy = energy/m
+ log["energy"] = energy
+
+ return Y, log
else:
return Y
-def _plot_barycenter(X, Y, matchings):
+def _plot_barycenter(X, Y, groupings):
"""
:param X: list of persistence diagrams.
:param Y: numpy.array of (n x 2). Aims to be an estimate of the barycenter
returned by lagrangian_barycenter(X, verbose=True).
- :param matchings: list of lists, such that L[k][i] = j if and only if
+ :param groupings: list of lists, such that L[k][i] = j if and only if
the i-th point of the barycenter is grouped with the j-th point of the k-th
diagram.
"""
@@ -232,7 +244,7 @@ def _plot_barycenter(X, Y, matchings):
# n_y = len(Y.points)
for i in range(len(X)):
- indices = matchings[i]
+ indices = groupings[i]
n_i = len(X[i])
for (y_j, x_i_j) in indices:
@@ -271,72 +283,3 @@ def _plot_barycenter(X, Y, matchings):
plt.show()
-
-def _test_perf():
- nb_repeat = 10
- nb_points_in_dgm = [5, 10, 20, 50, 100]
- nb_dmg = [3, 5, 10, 20]
-
- from time import time
- for m in nb_dmg:
- for n in nb_points_in_dgm:
- tstart = time()
- for _ in range(nb_repeat):
- X = [np.random.rand(n, 2) for _ in range(m)]
- for diag in X:
- #enforce having diagrams
- diag[:,1] = diag[:,1] + diag[:,0]
- _ = lagrangian_barycenter(X)
- tend = time()
- print("Computation of barycenter in %s sec, with k = %s diags and n = %s points per diag."%(np.round((tend - tstart)/nb_repeat, 2), m, n))
- print("********************")
-
-
-def _sanity_check(verbose):
- #dg1 = np.array([[0.2, 0.5]])
- #dg2 = np.array([[0.2, 0.7], [0.73, 0.88]])
- #dg3 = np.array([[0.3, 0.6], [0.7, 0.85], [0.2, 0.3]])
- #X = [dg1, dg2, dg3]
- #Y, a = lagrangian_barycenter(X, verbose=verbose)
- #_plot_barycenter(X, Y, a)
-
- #dg1 = np.array([[0.2, 0.5]])
- #dg2 = np.array([]) # The empty diagram
- #dg3 = np.array([[0.4, 0.8]])
- #X = [dg1, dg2, dg3]
- #Y, a = lagrangian_barycenter(X, verbose=verbose)
- #_plot_barycenter(X, Y, a)
-
- #dg1 = np.array([])
- #dg2 = np.array([]) # The empty diagram
- #dg3 = np.array([])
- #X = [dg1, dg2, dg3]
- #Y, a = lagrangian_barycenter(X, verbose=verbose)
- #_plot_barycenter(X, Y, a)
- #print(Y)
-
- dg1 = np.array([[0.1, 0.12], [0.21, 0.7], [0.4, 0.5], [0.3, 0.4], [0.35, 0.7], [0.5, 0.55], [0.32, 0.42], [0.1, 0.4], [0.2, 0.4]])
- dg2 = np.array([[0.09, 0.11], [0.3, 0.43], [0.5, 0.61], [0.3, 0.7], [0.42, 0.5], [0.35, 0.41], [0.74, 0.9], [0.5, 0.95], [0.35, 0.45], [0.13, 0.48], [0.32, 0.45]])
- dg3 = np.array([[0.1, 0.15], [0.1, 0.7], [0.2, 0.22], [0.55, 0.84], [0.11, 0.91], [0.61, 0.75], [0.33, 0.46], [0.12, 0.41], [0.32, 0.48]])
- dg4 = np.array([])
- X = [dg4]
- Y, a = lagrangian_barycenter(X, verbose=verbose)
- #_plot_barycenter(X, Y, a)
- print(Y)
- print(np.array_equal(Y, np.empty(shape=(0,2) )))
-
-
- #dg1 = np.array([[0.2, 0.5]])
- #dg2 = np.array([[0.2, 0.7]])
- #dg3 = np.array([[0.3, 0.6], [0.7, 0.8], [0.2, 0.3]])
- #dg4 = np.array([])
- #
- #bary, a = lagrangian_barycenter(pdiagset=[dg1, dg2, dg3, dg4],init=3, verbose=True)
- #_plot_barycenter([dg1, dg2, dg3, dg4], bary, a)
- #message = "Wasserstein barycenter estimated:"
- #print(message)
- #print(bary)
-
-if __name__=="__main__":
- _sanity_check(verbose = True)
- #_test_perf()