From dc4442bc402ac25290eb529b57407607434bb7ae Mon Sep 17 00:00:00 2001 From: tlacombe Date: Fri, 14 Feb 2020 14:53:51 +0100 Subject: barycenter update, adding more tests and details about log (assigments, cost, nb iter) --- src/python/gudhi/barycenter.py | 125 +++++++++++-------------- src/python/test/test_wasserstein_barycenter.py | 15 ++- 2 files changed, 69 insertions(+), 71 deletions(-) diff --git a/src/python/gudhi/barycenter.py b/src/python/gudhi/barycenter.py index 11098afe..4a00c457 100644 --- a/src/python/gudhi/barycenter.py +++ b/src/python/gudhi/barycenter.py @@ -2,6 +2,7 @@ import ot import numpy as np import scipy.spatial.distance as sc +from wasserstein import _build_dist_matrix, _perstot # This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. # See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. @@ -20,42 +21,19 @@ def _proj_on_diag(w): return np.array([(w[0] + w[1])/2 , (w[0] + w[1])/2]) -def _proj_on_diag_array(X): - ''' - :param X: (n x 2) array encoding the points of a persistent diagram. - :returns: (n x 2) array encoding the (respective orthogonal) projections of the points onto the diagonal - ''' - Z = (X[:,0] + X[:,1]) / 2. - return np.array([Z , Z]).T - - -def _build_dist_matrix(X, Y, p=2., q=2.): - ''' - :param X: (n x 2) numpy.array encoding the (points of the) first diagram. - :param Y: (m x 2) numpy.array encoding the second diagram. - :param q: Ground metric (i.e. norm l_q). - :param p: exponent for the Wasserstein metric. - :returns: (n+1) x (m+1) np.array encoding the cost matrix C. - For 1 <= i <= n, 1 <= j <= m, C[i,j] encodes the distance between X[i] and Y[j], while C[i, m+1] (resp. C[n+1, j]) encodes the distance (to the p) between X[i] (resp Y[j]) and its orthogonal proj onto the diagonal. - note also that C[n+1, m+1] = 0 (it costs nothing to move from the diagonal to the diagonal). - Note that for lagrangian_barycenter, one must use p=q=2. - ''' - Xdiag = _proj_on_diag_array(X) - Ydiag = _proj_on_diag_array(Y) - if np.isinf(q): - C = sc.cdist(X, Y, metric='chebyshev')**p - Cxd = np.linalg.norm(X - Xdiag, ord=q, axis=1)**p - Cdy = np.linalg.norm(Y - Ydiag, ord=q, axis=1)**p +def _mean(x, m): + """ + :param x: a list of 2D-points, off diagonal, x_0... x_{k-1} + :param m: total amount of points taken into account, that is we have (m-k) copies of diagonal + :returns: the weighted mean of x with (m-k) copies of the diagonal + """ + k = len(x) + if k > 0: + w = np.mean(x, axis=0) + w_delta = _proj_on_diag(w) + return (k * w + (m-k) * w_delta) / m else: - C = sc.cdist(X,Y, metric='minkowski', p=q)**p - Cxd = np.linalg.norm(X - Xdiag, ord=q, axis=1)**p - Cdy = np.linalg.norm(Y - Ydiag, ord=q, axis=1)**p - Cf = np.hstack((C, Cxd[:,None])) - Cdy = np.append(Cdy, 0) - - Cf = np.vstack((Cf, Cdy[None,:])) - - return Cf + return np.array([0, 0]) def _optimal_matching(X, Y, withcost=False): @@ -64,63 +42,63 @@ def _optimal_matching(X, Y, withcost=False): :param Y: numpy.array of size (m x 2) :param withcost: returns also the cost corresponding to this optimal matching :returns: numpy.array of shape (k x 2) encoding the list of edges in the optimal matching. - That is, [(i, j) ...], where (i,j) indicates that X[i] is matched to Y[j] - if i > len(X) or j > len(Y), it means they represent the diagonal. - + That is, [[i, j] ...], where (i,j) indicates that X[i] is matched to Y[j] + if i >= len(X) or j >= len(Y), it means they represent the diagonal. + They will be encoded by -1 afterwards. """ n = len(X) m = len(Y) + # Start by handling empty diagrams. Could it be shorten? if X.size == 0: # X is empty if Y.size == 0: # Y is empty - return np.array([[0,0]]) # the diagonal is matched to the diagonal and that's it... - else: - return np.column_stack([np.zeros(m+1, dtype=int), np.arange(m+1, dtype=int)]) + res = np.array([[0,0]]) # the diagonal is matched to the diagonal and that's it... + if withcost: + return res, 0 + else: + return res + else: # X is empty but not Y + res = np.array([[0, i] for i in range(m)]) + cost = _perstot(Y, order=2, internal_p=2)**2 + if withcost: + return res, cost + else: + return res elif Y.size == 0: # X is not empty but Y is empty - return np.column_stack([np.zeros(n+1, dtype=int), np.arange(n+1, dtype=int)]) - + res = np.array([[i,0] for i in range(n)]) + cost = _perstot(X, order=2, internal_p=2)**2 + if withcost: + return res, cost + else: + return res + # we know X, Y are not empty diags now - M = _build_dist_matrix(X, Y) + M = _build_dist_matrix(X, Y, order=2, internal_p=2) a = np.full(n+1, 1. / (n + m) ) # weight vector of the input diagram. Uniform here. a[-1] = a[-1] * m # normalized so that we have a probability measure, required by POT b = np.full(m+1, 1. / (n + m) ) # weight vector of the input diagram. Uniform here. b[-1] = b[-1] * n # so that we have a probability measure, required by POT P = ot.emd(a=a, b=b, M=M)*(n+m) - # Note : it seems POT return a permutation matrix in this situation, ie a vertex of the constraint set (generically true). + # Note : it seems POT returns a permutation matrix in this situation, ie a vertex of the constraint set (generically true). if withcost: - cost = np.sqrt(np.sum(np.multiply(P, M))) + cost = np.sum(np.multiply(P, M)) P[P < 0.5] = 0 # dirty trick to avoid some numerical issues... to be improved. - # return the list of (i,j) such that P[i,j] > 0, i.e. x_i is matched to y_j (should it be the diag). res = np.nonzero(P) + # return the list of (i,j) such that P[i,j] > 0, i.e. x_i is matched to y_j (should it be the diag). if withcost: return np.column_stack(res), cost return np.column_stack(res) -def _mean(x, m): - """ - :param x: a list of 2D-points, off diagonal, x_0... x_{k-1} - :param m: total amount of points taken into account, that is we have (m-k) copies of diagonal - :returns: the weighted mean of x with (m-k) copies of the diagonal - """ - k = len(x) - if k > 0: - w = np.mean(x, axis=0) - w_delta = _proj_on_diag(w) - return (k * w + (m-k) * w_delta) / m - else: - return np.array([0, 0]) - - def lagrangian_barycenter(pdiagset, init=None, verbose=False): """ Compute the estimated barycenter computed with the algorithm provided by Turner et al (2014). It is a local minimum of the corresponding Frechet function. - :param pdiagset: a list of size N containing numpy.array of shape (n x 2) + :param pdiagset: a list of size m containing numpy.array of shape (n x 2) (n can variate), encoding a set of persistence diagrams with only finite coordinates. :param init: The initial value for barycenter estimate. @@ -134,10 +112,13 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False): If verbose, returns a couple (Y, log) where Y is the barycenter estimate, and log is a dict that contains additional informations: - - assigments, a list of list of pairs (i,j), - That is, a[k] = [(i, j) ...], where (i,j) indicates that X[i] is matched to Y[j] + - groupings, a list of list of pairs (i,j), + That is, G[k] = [(i, j) ...], where (i,j) indicates that X[i] is matched to Y[j] if i > len(X) or j > len(Y), it means they represent the diagonal. - - energy, a float representing the Frechet mean value obtained. + - energy, a float representing the Frechet energy value obtained, + that is the mean of squared distances of observations to the output. + - nb_iter, integer representing the number of iterations performed before convergence + of the algorithm. """ X = pdiagset # to shorten notations, not a copy m = len(X) # number of diagrams we are averaging @@ -157,8 +138,11 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False): else: Y = init.copy() + nb_iter = 0 + converged = False # stoping criterion while not converged: + nb_iter += 1 K = len(Y) # current nb of points in Y (some might be on diagonal) G = np.zeros((K, m), dtype=int)-1 # will store for each j, the (index) point matched in each other diagram (might be the diagonal). # that is G[j, i] = k <=> y_j is matched to @@ -185,7 +169,6 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False): new_created_points.append(new_y) # Step 2 : Update current point position thanks to the groupings computed - to_delete = [] for j in range(K): matched_points = [X[i][G[j, i]] for i in range(m) if G[j, i] > -1] @@ -214,12 +197,16 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False): n_y = len(Y) for i in range(m): edges, cost = _optimal_matching(Y, X[i], withcost=True) - print(edges) - groupings.append([x_i_j for (y_j, x_i_j) in enumerate(edges) if y_j < n_y]) + n_x = len(X[i]) + G = edges[np.where(edges[:,0]= n_x) + G[idx,1] = -1 # -1 will encode the diagonal + groupings.append(G) energy += cost log["groupings"] = groupings energy = energy/m log["energy"] = energy + log["nb_iter"] = nb_iter return Y, log else: diff --git a/src/python/test/test_wasserstein_barycenter.py b/src/python/test/test_wasserstein_barycenter.py index 910d23ff..07242582 100755 --- a/src/python/test/test_wasserstein_barycenter.py +++ b/src/python/test/test_wasserstein_barycenter.py @@ -27,7 +27,18 @@ def test_lagrangian_barycenter(): res = np.array([[0.27916667, 0.55416667], [0.7375, 0.7625], [0.2375, 0.2625]]) dg7 = np.array([[0.1, 0.15], [0.1, 0.7], [0.2, 0.22], [0.55, 0.84], [0.11, 0.91], [0.61, 0.75], [0.33, 0.46], [0.12, 0.41], [0.32, 0.48]]) + dg8 = np.array([[0., 4.]]) + + # error crit. + eps = 0.000001 - assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg1, dg2, dg3, dg4],init=3, verbose=False) - res) < 0.001 + + assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg1, dg2, dg3, dg4],init=3, verbose=False) - res) < eps assert np.array_equal(lagrangian_barycenter(pdiagset=[dg4, dg5, dg6], verbose=False), np.empty(shape=(0,2))) - assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg7], verbose=False) - dg7) < 0.001 + assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg7], verbose=False) - dg7) < eps + Y, log = lagrangian_barycenter(pdiagset=[dg4, dg8], verbose=True) + assert np.linalg.norm(Y - np.array([[1,3]])) < eps + assert np.abs(log["energy"] - 2) < eps + assert np.array_equal(log["groupings"][0] , np.array([[0, -1]])) + assert np.array_equal(log["groupings"][1] , np.array([[0, 0]])) + assert lagrangian_barycenter(pdiagset = []) is None -- cgit v1.2.3