diff options
author | Romain Tavenard <romain.tavenard@univ-rennes2.fr> | 2019-06-21 11:15:42 +0200 |
---|---|---|
committer | Romain Tavenard <romain.tavenard@univ-rennes2.fr> | 2019-06-21 11:15:42 +0200 |
commit | cada9a3019997e8efb95d96c86985110f1e937b9 (patch) | |
tree | bc1cacefe01438ad44b0c8021bb0052d9190cafd /ot/lp | |
parent | 15b21611a3a93043d30c4eaaf9d622200453a884 (diff) |
Sparse G matrix for EMD1d + standard metrics computed without cdist
Diffstat (limited to 'ot/lp')
-rw-r--r-- | ot/lp/__init__.py | 54 |
1 files changed, 40 insertions, 14 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index c4457dc..decff29 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -10,6 +10,7 @@ Solvers for the original linear program OT problem import multiprocessing import numpy as np +from scipy.sparse import coo_matrix from .import cvx @@ -19,7 +20,8 @@ from ..utils import parmap from .cvx import barycenter from ..utils import dist -__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', 'emd_1d_sorted'] +__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', + 'emd_1d', 'emd2_1d'] def emd(a, b, M, numItermax=100000, log=False): @@ -311,16 +313,20 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None return X -def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', log=False): +def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', dense=True, log=False): """Solves the Earth Movers distance problem between 1d measures and returns the OT matrix """ - assert x_a.shape[1] == x_b.shape[1] == 1, "emd_1d should only be used " + \ - "with monodimensional data" - a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) + x_a = np.asarray(x_a, dtype=np.float64) + x_b = np.asarray(x_b, dtype=np.float64) + + assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \ + "emd_1d should only be used with monodimensional data" + assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \ + "emd_1d should only be used with monodimensional data" # if empty array given then use uniform distributions if len(a) == 0: @@ -328,16 +334,36 @@ def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', log=False): if len(b) == 0: b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0] - perm_a = np.argsort(x_a.reshape((-1, ))) - perm_b = np.argsort(x_b.reshape((-1, ))) - inv_perm_a = np.argsort(perm_a) - inv_perm_b = np.argsort(perm_b) - - G_sorted, cost = emd_1d_sorted(a, b, x_a[perm_a], x_b[perm_b], - metric=metric) - G = G_sorted[inv_perm_a, :][:, inv_perm_b] + 
x_a_1d = x_a.reshape((-1, )) + x_b_1d = x_b.reshape((-1, )) + perm_a = np.argsort(x_a_1d) + perm_b = np.argsort(x_b_1d) + + G_sorted, indices, cost = emd_1d_sorted(a, b, + x_a_1d[perm_a], x_b_1d[perm_b], + metric=metric) + G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])), + shape=(a.shape[0], b.shape[0])) + if dense: + G = G.todense() if log: log = {} log['cost'] = cost return G, log - return G
\ No newline at end of file + return G + + +def emd2_1d(a, b, x_a, x_b, metric='sqeuclidean', dense=True, log=False): + """Solves the Earth Movers distance problem between 1d measures and returns + the loss + + """ + # If we do not return G (log==False), then we should not cast it to dense + # (useless overhead) + G, log_emd = emd_1d(a=a, b=b, x_a=x_a, x_b=x_b, metric=metric, + dense=dense and log, log=True) + cost = log_emd['cost'] + if log: + log_emd = {'G': G} + return cost, log_emd + return cost
\ No newline at end of file |