summaryrefslogtreecommitdiff
path: root/ot/lp
diff options
context:
space:
mode:
authorAdrienCorenflos <adrien.corenflos@gmail.com>2020-04-02 10:41:24 +0100
committerAdrienCorenflos <adrien.corenflos@gmail.com>2020-04-02 10:41:24 +0100
commit60943d00bab1682d6fac22b1e1ba5e64569b4e78 (patch)
tree45efa2023c603cc66ac41e9d896a76fad97e9e5e /ot/lp
parent1e2e118e3a30224932ed2f012bb8f9f0f374ef2c (diff)
Auto PEP8
Diffstat (limited to 'ot/lp')
-rw-r--r--ot/lp/__init__.py19
1 files changed, 10 insertions, 9 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 4c968ca..1922785 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -12,16 +12,16 @@ Solvers for the original linear program OT problem
import multiprocessing
import sys
+
import numpy as np
from scipy.sparse import coo_matrix
-from .import cvx
-
+from . import cvx
+from .cvx import barycenter
# import compiled emd
from .emd_wrap import emd_c, check_result, emd_1d_sorted
-from ..utils import parmap
-from .cvx import barycenter
from ..utils import dist
+from ..utils import parmap
__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
'emd_1d', 'emd2_1d', 'wasserstein_1d']
@@ -458,7 +458,8 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
return res
-def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None):
+def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100,
+ stopThr=1e-7, verbose=False, log=None):
"""
Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
@@ -525,8 +526,8 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
T_sum = np.zeros((k, d))
- for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()):
-
+ for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights,
+ weights.tolist()):
M_i = dist(X, measure_locations_i)
T_i = emd(b, measure_weights_i, M_i)
T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
@@ -651,8 +652,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
if b.ndim == 0 or len(b) == 0:
b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]
- x_a_1d = x_a.reshape((-1, ))
- x_b_1d = x_b.reshape((-1, ))
+ x_a_1d = x_a.reshape((-1,))
+ x_b_1d = x_b.reshape((-1,))
perm_a = np.argsort(x_a_1d)
perm_b = np.argsort(x_b_1d)