diff options
-rw-r--r-- | ot/__init__.py | 4 | ||||
-rw-r--r-- | ot/lp/emd_wrap.pyx | 29 | ||||
-rw-r--r-- | test/test_ot.py | 23 |
3 files changed, 41 insertions, 15 deletions
diff --git a/ot/__init__.py b/ot/__init__.py index 5d5b700..f0e526c 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -22,7 +22,7 @@ from . import smooth from . import stochastic # OT functions -from .lp import emd, emd2, emd_1d +from .lp import emd, emd2, emd_1d, emd2_1d from .bregman import sinkhorn, sinkhorn2, barycenter from .da import sinkhorn_lpl1_mm @@ -32,5 +32,5 @@ from .utils import dist, unif, tic, toc, toq __version__ = "0.5.1" __all__ = ["emd", "emd2", 'emd_1d', "sinkhorn", "sinkhorn2", "utils", 'datasets', - 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', + 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', 'emd_1d', 'emd2_1d', 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim'] diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 2966206..ab88d7f 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -101,8 +101,8 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod @cython.wraparound(False) def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, np.ndarray[double, ndim=1, mode="c"] v_weights, - np.ndarray[double, ndim=2, mode="c"] u, - np.ndarray[double, ndim=2, mode="c"] v, + np.ndarray[double, ndim=1, mode="c"] u, + np.ndarray[double, ndim=1, mode="c"] v, str metric='sqeuclidean'): r""" Roro's stuff @@ -118,21 +118,34 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, cdef double m_ij = 0. - cdef np.ndarray[double, ndim=2, mode="c"] G = np.zeros((n, m), + cdef np.ndarray[double, ndim=1, mode="c"] G = np.zeros((n + m - 1, ), dtype=np.float64) + cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2), + dtype=np.int) + cdef int cur_idx = 0 while i < n and j < m: - m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)), - metric=metric)[0, 0] + if metric == 'sqeuclidean': + m_ij = (u[i] - v[j]) ** 2 + elif metric == 'cityblock' or metric == 'euclidean': + m_ij = np.abs(u[i] - v[j]) + else: + m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)), + metric=metric)[0, 0] if w_i < w_j or j == m - 1: cost += m_ij * w_i - G[i, j] = w_i + G[cur_idx] = w_i + indices[cur_idx, 0] = i + indices[cur_idx, 1] = j i += 1 w_j -= w_i w_i = u_weights[i] else: cost += m_ij * w_j - G[i, j] = w_j + G[cur_idx] = w_j + indices[cur_idx, 0] = i + indices[cur_idx, 1] = j j += 1 w_i -= w_j w_j = v_weights[j] - return G, cost
\ No newline at end of file + cur_idx += 1 + return G[:cur_idx], indices[:cur_idx], cost diff --git a/test/test_ot.py b/test/test_ot.py index 7008002..2a2e0a5 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -7,6 +7,7 @@ import warnings import numpy as np +from scipy.stats import wasserstein_distance import ot from ot.datasets import make_1D_gauss as gauss @@ -37,7 +38,7 @@ def test_emd_emd2(): # check G is identity np.testing.assert_allclose(G, np.eye(n) / n) - # check constratints + # check constraints np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn @@ -46,12 +47,13 @@ def test_emd_emd2(): np.testing.assert_allclose(w, 0) -def test_emd1d(): +def test_emd_1d_emd2_1d(): # test emd1d gives similar results as emd n = 20 m = 30 - u = np.random.randn(n, 1) - v = np.random.randn(m, 1) + rng = np.random.RandomState(0) + u = rng.randn(n, 1) + v = rng.randn(m, 1) M = ot.dist(u, v, metric='sqeuclidean') @@ -59,9 +61,20 @@ def test_emd1d(): wass = log["cost"] G_1d, log = ot.emd_1d([], [], u, v, metric='sqeuclidean', log=True) wass1d = log["cost"] + wass1d_emd2 = ot.emd2_1d([], [], u, v, metric='sqeuclidean', log=False) + wass1d_euc = ot.emd2_1d([], [], u, v, metric='euclidean', log=False) # check loss is similar np.testing.assert_allclose(wass, wass1d) + np.testing.assert_allclose(wass, wass1d_emd2) + + # check loss is similar to scipy's implementation for Euclidean metric + wass_sp = wasserstein_distance(u.reshape((-1, )), v.reshape((-1, ))) + np.testing.assert_allclose(wass_sp, wass1d_euc) + + # check constraints + np.testing.assert_allclose(np.ones((n, )) / n, G.sum(1)) + np.testing.assert_allclose(np.ones((m, )) / m, G.sum(0)) # check G is similar np.testing.assert_allclose(G, G_1d) @@ -86,7 +99,7 @@ def test_emd_empty(): # check G is identity np.testing.assert_allclose(G, np.eye(n) / n) - # check constratints + # check constraints np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn |