summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2019-12-02 11:31:32 +0100
committerRémi Flamary <remi.flamary@gmail.com>2019-12-02 11:31:32 +0100
commita6a654de5e78dd388a793fbd26f60045b05d519c (patch)
treea8e3049507db770892d05c7747b2bf083c2d9af8 /ot
parent57321bd0172c97b77dfc8b14972c18d063b6dda8 (diff)
proper documentation and parameter
Diffstat (limited to 'ot')
-rw-r--r--ot/lp/EMD.h2
-rw-r--r--ot/lp/EMD_wrapper.cpp3
-rw-r--r--ot/lp/__init__.py16
-rw-r--r--ot/lp/emd_wrap.pyx10
4 files changed, 23 insertions, 8 deletions
diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h
index bc513d2..9896091 100644
--- a/ot/lp/EMD.h
+++ b/ot/lp/EMD.h
@@ -33,7 +33,7 @@ enum ProblemType {
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter);
int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
- long *iG, long *jG, double *G,
+ long *iG, long *jG, double *G, long * nG,
double* alpha, double* beta, double *cost, int maxIter);
#endif
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp
index 2aa44c1..9be2cdc 100644
--- a/ot/lp/EMD_wrapper.cpp
+++ b/ot/lp/EMD_wrapper.cpp
@@ -108,7 +108,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
- long *iG, long *jG, double *G,
+ long *iG, long *jG, double *G, long * nG,
double* alpha, double* beta, double *cost, int maxIter) {
// beware M and C anre strored in row major C style!!!
@@ -202,6 +202,7 @@ int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
cur++;
}
}
+ *nG=cur; // nb of value +1 for numpy indexing
}
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 4fec7d9..d476071 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -27,7 +27,7 @@ __all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
'emd_1d', 'emd2_1d', 'wasserstein_1d']
-def emd(a, b, M, numItermax=100000, log=False, sparse=False):
+def emd(a, b, M, numItermax=100000, log=False, dense=True):
r"""Solves the Earth Movers distance problem and returns the OT matrix
@@ -62,6 +62,10 @@ def emd(a, b, M, numItermax=100000, log=False, sparse=False):
log: bool, optional (default=False)
If True, returns a dictionary containing the cost and dual
variables. Otherwise returns only the optimal transportation matrix.
+ dense: boolean, optional (default=True)
+ If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
+ Otherwise returns a sparse representation using scipy's `coo_matrix`
+ format.
Returns
-------
@@ -103,6 +107,8 @@ def emd(a, b, M, numItermax=100000, log=False, sparse=False):
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
+ sparse= not dense
+
# if empty array given then use uniform distributions
if len(a) == 0:
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
@@ -128,7 +134,7 @@ def emd(a, b, M, numItermax=100000, log=False, sparse=False):
def emd2(a, b, M, processes=multiprocessing.cpu_count(),
- numItermax=100000, log=False, sparse=False, return_matrix=False):
+ numItermax=100000, log=False, dense=True, return_matrix=False):
r"""Solves the Earth Movers distance problem and returns the loss
.. math::
@@ -166,6 +172,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
variables. Otherwise returns only the optimal transportation cost.
return_matrix: boolean, optional (default=False)
If True, returns the optimal transportation matrix in the log.
+ dense: boolean, optional (default=True)
+ If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
+ Otherwise returns a sparse representation using scipy's `coo_matrix`
+ format.
Returns
-------
@@ -207,6 +217,8 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
+ sparse=not dense
+
# problem with pikling Forks
if sys.platform.endswith('win32'):
processes=1
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index f183995..4b6cdce 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -21,7 +21,7 @@ import warnings
cdef extern from "EMD.h":
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter)
int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
- long *iG, long *jG, double *G,
+ long *iG, long *jG, double *G, long * nG,
double* alpha, double* beta, double *cost, int maxIter)
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
@@ -75,7 +75,8 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
max_iter : int
The maximum number of iterations before stopping the optimization
algorithm if it has not converged.
-
+ sparse : bool
+ Returning a sparse transport matrix if set to True
Returns
-------
@@ -87,6 +88,7 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
cdef int n2= M.shape[1]
cdef int nmax=n1+n2-1
cdef int result_code = 0
+ cdef int nG=0
cdef double cost=0
cdef np.ndarray[double, ndim=1, mode="c"] alpha=np.zeros(n1)
@@ -111,10 +113,10 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
jG=np.zeros(nmax,dtype=np.int)
- result_code = EMD_wrap_return_sparse(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <long*> iG.data, <long*> jG.data, <double*> Gv.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
+ result_code = EMD_wrap_return_sparse(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <long*> iG.data, <long*> jG.data, <double*> Gv.data, <long*> &nG, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
- return Gv, iG, jG, cost, alpha, beta, result_code
+ return Gv[:nG], iG[:nG], jG[:nG], cost, alpha, beta, result_code
else: