summaryrefslogtreecommitdiff
path: root/ot/lp/emd_wrap.pyx
diff options
context:
space:
mode:
Diffstat (limited to 'ot/lp/emd_wrap.pyx')
-rw-r--r--ot/lp/emd_wrap.pyx150
1 files changed, 140 insertions, 10 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 83ee6aa..d345fd4 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -10,13 +10,19 @@ Cython linker with C solver
import numpy as np
cimport numpy as np
+from ..utils import dist
+
cimport cython
+cimport libc.math as math
import warnings
cdef extern from "EMD.h":
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter)
+ int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
+ long *iG, long *jG, double *G, long * nG,
+ double* alpha, double* beta, double *cost, int maxIter)
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
@@ -34,9 +40,11 @@ def check_result(result_code):
return message
+
+
@cython.boundscheck(False)
@cython.wraparound(False)
-def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter):
+def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, bint dense):
"""
Solves the Earth Movers distance problem and returns the optimal transport matrix
@@ -55,33 +63,50 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
- M is the metric cost matrix
- a and b are the sample weights
+ .. warning::
+ Note that the M matrix needs to be a C-order :py.cls:`numpy.array`
+
+ .. warning::
+ The C++ solver discards all samples in the distributions with
+ zeros weights. This means that while the primal variable (transport
+ matrix) is exact, the solver only returns feasible dual potentials
+ on the samples with weights different from zero.
+
Parameters
----------
- a : (ns,) ndarray, float64
+ a : (ns,) numpy.ndarray, float64
source histogram
- b : (nt,) ndarray, float64
+ b : (nt,) numpy.ndarray, float64
target histogram
- M : (ns,nt) ndarray, float64
+ M : (ns,nt) numpy.ndarray, float64
loss matrix
max_iter : int
The maximum number of iterations before stopping the optimization
algorithm if it has not converged.
-
+ dense : bool
+ Return a sparse transport matrix if set to False
Returns
-------
- gamma: (ns x nt) ndarray
+ gamma: (ns x nt) numpy.ndarray
Optimal transportation matrix for the given parameters
"""
cdef int n1= M.shape[0]
cdef int n2= M.shape[1]
+ cdef int nmax=n1+n2-1
+ cdef int result_code = 0
+ cdef int nG=0
cdef double cost=0
- cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2])
cdef np.ndarray[double, ndim=1, mode="c"] alpha=np.zeros(n1)
cdef np.ndarray[double, ndim=1, mode="c"] beta=np.zeros(n2)
+ cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([0, 0])
+
+ cdef np.ndarray[double, ndim=1, mode="c"] Gv=np.zeros(0)
+ cdef np.ndarray[long, ndim=1, mode="c"] iG=np.zeros(0,dtype=np.int)
+ cdef np.ndarray[long, ndim=1, mode="c"] jG=np.zeros(0,dtype=np.int)
if not len(a):
a=np.ones((n1,))/n1
@@ -89,7 +114,112 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
if not len(b):
b=np.ones((n2,))/n2
- # calling the function
- cdef int result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
+ if dense:
+ # init OT matrix
+ G=np.zeros([n1, n2])
+
+ # calling the function
+ result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
+
+ return G, cost, alpha, beta, result_code
+
+
+ else:
+
+ # init sparse OT matrix
+ Gv=np.zeros(nmax)
+ iG=np.zeros(nmax,dtype=np.int)
+ jG=np.zeros(nmax,dtype=np.int)
+
+
+ result_code = EMD_wrap_return_sparse(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <long*> iG.data, <long*> jG.data, <double*> Gv.data, <long*> &nG, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
- return G, cost, alpha, beta, result_code
+
+ return Gv[:nG], iG[:nG], jG[:nG], cost, alpha, beta, result_code
+
+
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
+ np.ndarray[double, ndim=1, mode="c"] v_weights,
+ np.ndarray[double, ndim=1, mode="c"] u,
+ np.ndarray[double, ndim=1, mode="c"] v,
+ str metric='sqeuclidean',
+ double p=1.):
+ r"""
+ Solves the Earth Movers distance problem between sorted 1d measures and
+ returns the OT matrix and the associated cost
+
+ Parameters
+ ----------
+ u_weights : (ns,) ndarray, float64
+ Source histogram
+ v_weights : (nt,) ndarray, float64
+ Target histogram
+ u : (ns,) ndarray, float64
+ Source dirac locations (on the real line)
+ v : (nt,) ndarray, float64
+ Target dirac locations (on the real line)
+ metric: str, optional (default='sqeuclidean')
+ Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
+ Due to implementation details, this function runs faster when
+ `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
+ are used.
+ p: float, optional (default=1.0)
+ The p-norm to apply for if metric='minkowski'
+
+ Returns
+ -------
+ gamma: (n, ) ndarray, float64
+ Values in the Optimal transportation matrix
+ indices: (n, 2) ndarray, int64
+ Indices of the values stored in gamma for the Optimal transportation
+ matrix
+ cost
+ cost associated to the optimal transportation
+ """
+ cdef double cost = 0.
+ cdef int n = u_weights.shape[0]
+ cdef int m = v_weights.shape[0]
+
+ cdef int i = 0
+ cdef double w_i = u_weights[0]
+ cdef int j = 0
+ cdef double w_j = v_weights[0]
+
+ cdef double m_ij = 0.
+
+ cdef np.ndarray[double, ndim=1, mode="c"] G = np.zeros((n + m - 1, ),
+ dtype=np.float64)
+ cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2),
+ dtype=np.int)
+ cdef int cur_idx = 0
+ while i < n and j < m:
+ if metric == 'sqeuclidean':
+ m_ij = (u[i] - v[j]) * (u[i] - v[j])
+ elif metric == 'cityblock' or metric == 'euclidean':
+ m_ij = math.fabs(u[i] - v[j])
+ elif metric == 'minkowski':
+ m_ij = math.pow(math.fabs(u[i] - v[j]), p)
+ else:
+ m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)),
+ metric=metric)[0, 0]
+ if w_i < w_j or j == m - 1:
+ cost += m_ij * w_i
+ G[cur_idx] = w_i
+ indices[cur_idx, 0] = i
+ indices[cur_idx, 1] = j
+ i += 1
+ w_j -= w_i
+ w_i = u_weights[i]
+ else:
+ cost += m_ij * w_j
+ G[cur_idx] = w_j
+ indices[cur_idx, 0] = i
+ indices[cur_idx, 1] = j
+ j += 1
+ w_i -= w_j
+ w_j = v_weights[j]
+ cur_idx += 1
+ return G[:cur_idx], indices[:cur_idx], cost