diff options
Diffstat (limited to 'ot/lp')
-rw-r--r-- | ot/lp/__init__.py | 13 | ||||
-rw-r--r-- | ot/lp/emd_wrap.pyx | 3 |
2 files changed, 8 insertions, 8 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 719032b..76c9ec0 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -530,13 +530,13 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, def wasserstein_1d(x_a, x_b, a=None, b=None, p=1., dense=True, log=False): - """Solves the Wasserstein distance problem between 1d measures and returns + """Solves the p-Wasserstein distance problem between 1d measures and returns the OT matrix .. math:: - \gamma = arg\min_\gamma \left(\sum_i \sum_j \gamma_{ij} - |x_a[i] - x_b[j]|^p \right)^{1/p} + \gamma = arg\min_\gamma \left( \sum_i \sum_j \gamma_{ij} + |x_a[i] - x_b[j]|^p \\right)^{1/p} s.t. \gamma 1 = a, \gamma^T 1= b, @@ -617,15 +617,14 @@ def wasserstein_1d(x_a, x_b, a=None, b=None, p=1., dense=True, log=False): dense=dense, log=log) -def wasserstein2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., - dense=True, log=False): - """Solves the Wasserstein distance problem between 1d measures and returns +def wasserstein2_1d(x_a, x_b, a=None, b=None, p=1., dense=True, log=False): + """Solves the p-Wasserstein distance problem between 1d measures and returns the loss .. math:: \gamma = arg\min_\gamma \left( \sum_i \sum_j \gamma_{ij} - |x_a[i] - x_b[j]|^p \right)^{1/p} + |x_a[i] - x_b[j]|^p \\right)^{1/p} s.t. \gamma 1 = a, \gamma^T 1= b, diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 7134136..42b848f 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -13,6 +13,7 @@ cimport numpy as np from ..utils import dist cimport cython +cimport libc.math as math import warnings @@ -159,7 +160,7 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, elif metric == 'cityblock' or metric == 'euclidean': m_ij = abs(u[i] - v[j]) elif metric == 'minkowski': - m_ij = abs(u[i] - v[j]) ** p + m_ij = math.pow(abs(u[i] - v[j]), p) else: m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)), metric=metric)[0, 0] |