diff options
Diffstat (limited to 'ot')
-rw-r--r-- | ot/__init__.py | 4 | ||||
-rw-r--r-- | ot/lp/__init__.py | 43 | ||||
-rw-r--r-- | ot/lp/emd_wrap.pyx | 35 |
3 files changed, 75 insertions, 7 deletions
diff --git a/ot/__init__.py b/ot/__init__.py index b74b924..5d5b700 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -22,7 +22,7 @@ from . import smooth from . import stochastic # OT functions -from .lp import emd, emd2 +from .lp import emd, emd2, emd_1d from .bregman import sinkhorn, sinkhorn2, barycenter from .da import sinkhorn_lpl1_mm @@ -31,6 +31,6 @@ from .utils import dist, unif, tic, toc, toq __version__ = "0.5.1" -__all__ = ["emd", "emd2", "sinkhorn", "sinkhorn2", "utils", 'datasets', +__all__ = ["emd", "emd2", 'emd_1d', "sinkhorn", "sinkhorn2", "utils", 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim'] diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 02cbd8c..49ded5b 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -14,12 +14,12 @@ import numpy as np from .import cvx # import compiled emd -from .emd_wrap import emd_c, check_result +from .emd_wrap import emd_c, check_result, emd_1d_sorted from ..utils import parmap from .cvx import barycenter from ..utils import dist -__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx'] +__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', 'emd_1d_sorted'] def emd(a, b, M, numItermax=100000, log=False): @@ -94,7 +94,7 @@ def emd(a, b, M, numItermax=100000, log=False): b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64) - # if empty array given then use unifor distributions + # if empty array given then use uniform distributions if len(a) == 0: a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] if len(b) == 0: @@ -187,7 +187,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64) - # if empty array given then use unifor distributions + # if empty array given then use uniform distributions if len(a) == 0: a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] if len(b) == 0: @@ -308,4 +308,37 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None log_dict['displacement_square_norms'] = displacement_square_norms return X, log_dict else: - return X
\ No newline at end of file + return X + + +def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', log=False): + """Solves the Earth Movers distance problem between 1d measures and returns + the OT matrix + + """ + assert x_a.shape[1] == x_b.shape[1] == 1, "emd_1d should only be used " + \ + "with monodimensional data" + + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + + # if empty array given then use uniform distributions + if len(a) == 0: + a = np.ones((x_a.shape[0],), dtype=np.float64) / x_a.shape[0] + if len(b) == 0: + b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0] + + perm_a = np.argsort(x_a.reshape((-1, ))) + perm_b = np.argsort(x_b.reshape((-1, ))) + inv_perm_a = np.argsort(perm_a) + inv_perm_b = np.argsort(perm_b) + + M = dist(x_a[perm_a], x_b[perm_b], metric=metric) + + G_sorted, cost = emd_1d_sorted(a, b, M) + G = G_sorted[inv_perm_a, :][:, inv_perm_b] + if log: + log = {} + log['cost'] = cost + return G, log + return G
\ No newline at end of file diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 83ee6aa..a3d189d 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -93,3 +93,38 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod cdef int result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter) return G, cost, alpha, beta, result_code + + +@cython.boundscheck(False) +@cython.wraparound(False) +def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, + np.ndarray[double, ndim=1, mode="c"] v_weights, + np.ndarray[double, ndim=2, mode="c"] M): + r""" + Roro's stuff + """ + cdef double cost = 0. + cdef int n = u_weights.shape[0] + cdef int m = v_weights.shape[0] + + cdef int i = 0 + cdef double w_i = u_weights[0] + cdef int j = 0 + cdef double w_j = v_weights[0] + + cdef np.ndarray[double, ndim=2, mode="c"] G = np.zeros((n, m), + dtype=np.float64) + while i < n and j < m: + if w_i < w_j or j == m - 1: + cost += M[i, j] * w_i + G[i, j] = w_i + i += 1 + w_j -= w_i + w_i = u_weights[i] + else: + cost += M[i, j] * w_j + G[i, j] = w_j + j += 1 + w_i -= w_j + w_j = v_weights[j] + return G, cost
\ No newline at end of file |