diff options
author | ievred <ievgen.redko@univ-st-etienne.fr> | 2020-04-07 11:43:49 +0200 |
---|---|---|
committer | ievred <ievgen.redko@univ-st-etienne.fr> | 2020-04-07 11:43:49 +0200 |
commit | bf9b170ded15d083efb73ed8c4e3e8fd23796211 (patch) | |
tree | a6b5b387525b4a166ee055b09ffd47381e6b864b /ot/lp | |
parent | b32c81542c99cc48944fbeb13e4648f9947ac19d (diff) | |
parent | d399f62ec480b0ad46c6e35957543eeb0738854a (diff) |
Merge branch 'master' into laplace_da
Diffstat (limited to 'ot/lp')
-rw-r--r-- | ot/lp/__init__.py | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index c4b5834..8d1baa0 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -10,16 +10,16 @@ Solvers for the original linear program OT problem import multiprocessing import sys + import numpy as np from scipy.sparse import coo_matrix -from .import cvx - +from . import cvx +from .cvx import barycenter # import compiled emd from .emd_wrap import emd_c, check_result, emd_1d_sorted -from ..utils import parmap -from .cvx import barycenter from ..utils import dist +from ..utils import parmap __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', 'emd_1d', 'emd2_1d', 'wasserstein_1d'] @@ -456,7 +456,8 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), return res -def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None): +def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, + stopThr=1e-7, verbose=False, log=None): """ Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance) @@ -523,8 +524,8 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None T_sum = np.zeros((k, d)) - for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()): - + for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, + weights.tolist()): M_i = dist(X, measure_locations_i) T_i = emd(b, measure_weights_i, M_i) T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i) @@ -649,12 +650,12 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, if b.ndim == 0 or len(b) == 0: b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0] - x_a_1d = x_a.reshape((-1, )) - x_b_1d = x_b.reshape((-1, )) + x_a_1d = x_a.reshape((-1,)) + x_b_1d = x_b.reshape((-1,)) perm_a = np.argsort(x_a_1d) perm_b = np.argsort(x_b_1d) - G_sorted, indices, cost = emd_1d_sorted(a, b, + G_sorted, indices, cost = emd_1d_sorted(a[perm_a], b[perm_b], x_a_1d[perm_a], x_b_1d[perm_b], metric=metric, p=p) G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])), |