summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2020-03-17 10:55:14 +0100
committertlacombe <lacombe1993@gmail.com>2020-03-17 10:55:14 +0100
commitcdc57712ca159f3044453cef41e31ebc03617a1b (patch)
tree671e6527ac7c0b3e3ee2a0d14f1d232d97634699
parent2de9709b63045c484aa1c53f72c870eb210880d9 (diff)
removed _optimal_matching from barycenter as it is now handled by wasserstein_distance.
-rw-r--r--src/python/gudhi/barycenter.py85
-rwxr-xr-xsrc/python/test/test_wasserstein_barycenter.py2
2 files changed, 9 insertions, 78 deletions
diff --git a/src/python/gudhi/barycenter.py b/src/python/gudhi/barycenter.py
index 517cdb2f..0490fdd1 100644
--- a/src/python/gudhi/barycenter.py
+++ b/src/python/gudhi/barycenter.py
@@ -12,8 +12,7 @@ import ot
import numpy as np
import scipy.spatial.distance as sc
-from gudhi.wasserstein import _build_dist_matrix, _perstot
-
+from gudhi.wasserstein import wasserstein_distance, _perstot
def _mean(x, m):
@@ -32,70 +31,6 @@ def _mean(x, m):
return np.array([0, 0])
-def _optimal_matching(X, Y, withcost=False):
- '''
- :param X: numpy.array of size (n x 2)
- :param Y: numpy.array of size (m x 2)
- :param withcost: returns also the cost corresponding to the optimal matching
- :returns: numpy.array of shape (k x 2) encoding the list of edges
- in the optimal matching.
- That is, [[i, j] ...], where (i,j) indicates
- that X[i] is matched to Y[j]
- if i >= len(X) or j >= len(Y), it means they
- represent the diagonal.
- They will be encoded by -1 afterwards.
-
- NOTE : this code will be removed for final merge,
- and wasserstein.optimal_matching will be used instead.
- '''
-
- n = len(X)
- m = len(Y)
- # Start by handling empty diagrams. Could it be shorten?
- if X.size == 0: # X is empty
- if Y.size == 0: # Y is empty
- res = np.array([[0,0]]) # the diagonal is matched to the diagonal
- if withcost:
- return res, 0
- else:
- return res
- else: # X is empty but not Y
- res = np.array([[0, i] for i in range(m)])
- cost = _perstot(Y, order=2, internal_p=2)**2
- if withcost:
- return res, cost
- else:
- return res
- elif Y.size == 0: # X is not empty but Y is empty
- res = np.array([[i,0] for i in range(n)])
- cost = _perstot(X, order=2, internal_p=2)**2
- if withcost:
- return res, cost
- else:
- return res
-
- # we know X, Y are not empty diags now
- M = _build_dist_matrix(X, Y, order=2, internal_p=2)
-
- a = np.ones(n+1)
- a[-1] = m
- b = np.ones(m+1)
- b[-1] = n
- P = ot.emd(a=a, b=b, M=M)
- # Note : it seems POT returns a permutation matrix in this situation,
- # ie a vertex of the constraint set (generically true).
- if withcost:
- cost = np.sum(np.multiply(P, M))
- P[P < 0.5] = 0 # dirty trick to avoid some numerical issues... to improve.
- res = np.argwhere(P)
-
- # return the list of (i,j) such that P[i,j] > 0,
- #i.e. x_i is matched to y_j (should it be the diag).
- if withcost:
- return res, cost
- return res
-
-
def lagrangian_barycenter(pdiagset, init=None, verbose=False):
'''
:param pdiagset: a list of size m containing numpy.array of shape (n x 2)
@@ -166,16 +101,15 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False):
# Step 1 : compute optimal matching (Y, X_i) for each X_i
# and create new points in Y if needed
for i in range(m):
- indices = _optimal_matching(Y, X[i])
+ _, indices = wasserstein_distance(Y, X[i], matching=True, order=2., internal_p=2.)
for y_j, x_i_j in indices:
- if y_j < K: # we matched an off diagonal point to x_i_j...
- # ...which is also an off-diagonal point.
- if x_i_j < nb_off_diag[i]:
+ if y_j >= 0: # we matched an off diagonal point to x_i_j...
+ if x_i_j >= 0: # ...which is also an off-diagonal point.
G[y_j, i] = x_i_j
else: # ...which is a diagonal point
G[y_j, i] = -1 # -1 stands for the diagonal (mask)
else: # We matched a diagonal point to x_i_j...
- if x_i_j < nb_off_diag[i]: # which is a off-diag point !
+ if x_i_j >= 0: # which is a off-diag point !
# need to create new point in Y
new_y = _mean(np.array([X[i][x_i_j]]), m)
# Average this point with (m-1) copies of Delta
@@ -209,15 +143,12 @@ def lagrangian_barycenter(pdiagset, init=None, verbose=False):
log = {}
n_y = len(Y)
for i in range(m):
- edges, cost = _optimal_matching(Y, X[i], withcost=True)
- n_x = len(X[i])
- G = edges[np.where(edges[:,0]<n_y)]
- idx = np.where(G[:,1] >= n_x)
- G[idx,1] = -1 # -1 will encode the diagonal
- groupings.append(G)
+ cost, edges = wasserstein_distance(Y, X[i], matching=True, order=2., internal_p=2.)
+ groupings.append(edges)
energy += cost
log["groupings"] = groupings
energy = energy/m
+ print(energy)
log["energy"] = energy
log["nb_iter"] = nb_iter
diff --git a/src/python/test/test_wasserstein_barycenter.py b/src/python/test/test_wasserstein_barycenter.py
index 5167cb84..4d18616b 100755
--- a/src/python/test/test_wasserstein_barycenter.py
+++ b/src/python/test/test_wasserstein_barycenter.py
@@ -38,7 +38,7 @@ def test_lagrangian_barycenter():
assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg7], verbose=False) - dg7) < eps
Y, log = lagrangian_barycenter(pdiagset=[dg4, dg8], verbose=True)
assert np.linalg.norm(Y - np.array([[1,3], [5, 7]])) < eps
- assert np.abs(log["energy"] - 4) < eps
+ assert np.abs(log["energy"] - 2) < eps
assert np.array_equal(log["groupings"][0] , np.array([[0, -1], [1, -1]]))
assert np.array_equal(log["groupings"][1] , np.array([[0, 0], [1, 1]]))
assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg8, dg4], init=np.array([[0.2, 0.6], [0.5, 0.7]]), verbose=False) - np.array([[1, 3], [5, 7]])) < eps