summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-28 08:53:16 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-28 08:53:16 +0200
commitcf2d92e151f816e6ddcfc4b64cbda1f8f7bde9df (patch)
tree11cc4cc57453a1ce1739604bd6d64bb6a3bf67e8 /ot
parent123a7c70488d68bff85bbad50c6dd8d390f1e728 (diff)
complete doc emd
Diffstat (limited to 'ot')
-rw-r--r--ot/bregman.py5
-rw-r--r--ot/lp/__init__.py61
-rw-r--r--ot/lp/emd.cpp126
-rw-r--r--ot/lp/emd.pyx12
4 files changed, 132 insertions, 72 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 17ec06f..0d2c099 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -37,7 +37,7 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
samples in the target domain
M : np.ndarray (ns,nt)
loss matrix
- reg: float()
+ reg: float
Regularization term >0
@@ -54,7 +54,8 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
See Also
--------
- ot.emd.emd : Unregularized optimal ransport
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
"""
# init data
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 46662b7..568e370 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -1,3 +1,62 @@
-from .emd import emd
+from .emd import emd_c
+import numpy as np
+
+def emd(a,b,M):
+ """
+ Solves the Earth Movers distance problem and returns the optimal transport matrix
+
+ gamm=emd(a,b,M)
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the metric cost matrix
+ - a and b are the sample weights
+
+ Parameters
+ ----------
+ a : (ns,) ndarray, float64
+ Source histogram (uniform weigth if empty list)
+ b : (nt,) ndarray, float64
+ Target histogram (uniform weigth if empty list)
+ M : (ns,nt) ndarray, float64
+ loss matrix
+
+ Examples
+ --------
+
+ Simple example with obvious solution. The function :func:emd accepts lists and
+ perform automatic conversion tu numpy arrays
+
+ >>> a=[.5,.5]
+ >>> b=[.5,.5]
+ >>> M=[[0.,1.],[1.,0.]]
+ >>> ot.emd(a,b,M)
+ array([[ 0.5, 0. ],
+ [ 0. , 0.5]])
+
+ Returns
+ -------
+ gamma: (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+
+ """
+ a=np.asarray(a,dtype=np.float64)
+ b=np.asarray(b,dtype=np.float64)
+
+ if len(a)==0:
+ a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0]
+ if len(b)==0:
+ b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1]
+
+ return emd_c(a,b,M)
+
diff --git a/ot/lp/emd.cpp b/ot/lp/emd.cpp
index 2343af6..26d243f 100644
--- a/ot/lp/emd.cpp
+++ b/ot/lp/emd.cpp
@@ -6,11 +6,11 @@
"depends": [
"/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/arrayobject.h",
"/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/ufuncobject.h",
- "ot/emd/EMD.h"
+ "ot/lp/EMD.h"
],
"include_dirs": [
"/usr/lib/python2.7/dist-packages/numpy/core/include",
- "/home/rflamary/PYTHON/POT/ot/emd"
+ "/home/rflamary/PYTHON/POT/ot/lp"
],
"language": "c++"
}
@@ -260,8 +260,8 @@ static CYTHON_INLINE float __PYX_NAN() {
#endif
#endif
-#define __PYX_HAVE__ot__emd__emd
-#define __PYX_HAVE_API__ot__emd__emd
+#define __PYX_HAVE__ot__lp__emd
+#define __PYX_HAVE_API__ot__lp__emd
#include "string.h"
#include "stdio.h"
#include "stdlib.h"
@@ -497,7 +497,7 @@ static const char *__pyx_filename;
static const char *__pyx_f[] = {
- "ot/emd/emd.pyx",
+ "ot/lp/emd.pyx",
"__init__.pxd",
"type.pxd",
};
@@ -1129,12 +1129,12 @@ static CYTHON_INLINE char *__pyx_f_5numpy__util_dtypestring(PyArray_Descr *, cha
/* Module declarations from 'cython' */
-/* Module declarations from 'ot.emd.emd' */
+/* Module declarations from 'ot.lp.emd' */
static __Pyx_TypeInfo __Pyx_TypeInfo_double = { "double", NULL, sizeof(double), { 0 }, 0, 'R', 0, 0 };
-#define __Pyx_MODULE_NAME "ot.emd.emd"
-int __pyx_module_is_main_ot__emd__emd = 0;
+#define __Pyx_MODULE_NAME "ot.lp.emd"
+int __pyx_module_is_main_ot__lp__emd = 0;
-/* Implementation of 'ot.emd.emd' */
+/* Implementation of 'ot.lp.emd' */
static PyObject *__pyx_builtin_ValueError;
static PyObject *__pyx_builtin_range;
static PyObject *__pyx_builtin_RuntimeError;
@@ -1161,21 +1161,21 @@ static char __pyx_k_Zg[] = "Zg";
static char __pyx_k_n1[] = "n1";
static char __pyx_k_n2[] = "n2";
static char __pyx_k_np[] = "np";
-static char __pyx_k_emd[] = "emd";
static char __pyx_k_cost[] = "cost";
static char __pyx_k_main[] = "__main__";
static char __pyx_k_ones[] = "ones";
static char __pyx_k_test[] = "__test__";
+static char __pyx_k_emd_c[] = "emd_c";
static char __pyx_k_numpy[] = "numpy";
static char __pyx_k_range[] = "range";
static char __pyx_k_zeros[] = "zeros";
static char __pyx_k_import[] = "__import__";
+static char __pyx_k_ot_lp_emd[] = "ot.lp.emd";
static char __pyx_k_ValueError[] = "ValueError";
-static char __pyx_k_ot_emd_emd[] = "ot.emd.emd";
static char __pyx_k_RuntimeError[] = "RuntimeError";
static char __pyx_k_ndarray_is_not_C_contiguous[] = "ndarray is not C contiguous";
static char __pyx_k_Created_on_Thu_Sep_11_08_42_08[] = "\nCreated on Thu Sep 11 08:42:08 2014\n\n@author: rflamary\n";
-static char __pyx_k_home_rflamary_PYTHON_POT_ot_emd[] = "/home/rflamary/PYTHON/POT/ot/emd/emd.pyx";
+static char __pyx_k_home_rflamary_PYTHON_POT_ot_lp[] = "/home/rflamary/PYTHON/POT/ot/lp/emd.pyx";
static char __pyx_k_unknown_dtype_code_in_numpy_pxd[] = "unknown dtype code in numpy.pxd (%d)";
static char __pyx_k_Format_string_allocated_too_shor[] = "Format string allocated too short, see comment in numpy.pxd";
static char __pyx_k_Non_native_byte_order_not_suppor[] = "Non-native byte order not supported";
@@ -1191,8 +1191,8 @@ static PyObject *__pyx_n_s_ValueError;
static PyObject *__pyx_n_s_a;
static PyObject *__pyx_n_s_b;
static PyObject *__pyx_n_s_cost;
-static PyObject *__pyx_n_s_emd;
-static PyObject *__pyx_kp_s_home_rflamary_PYTHON_POT_ot_emd;
+static PyObject *__pyx_n_s_emd_c;
+static PyObject *__pyx_kp_s_home_rflamary_PYTHON_POT_ot_lp;
static PyObject *__pyx_n_s_import;
static PyObject *__pyx_n_s_main;
static PyObject *__pyx_n_s_n1;
@@ -1202,12 +1202,12 @@ static PyObject *__pyx_kp_u_ndarray_is_not_Fortran_contiguou;
static PyObject *__pyx_n_s_np;
static PyObject *__pyx_n_s_numpy;
static PyObject *__pyx_n_s_ones;
-static PyObject *__pyx_n_s_ot_emd_emd;
+static PyObject *__pyx_n_s_ot_lp_emd;
static PyObject *__pyx_n_s_range;
static PyObject *__pyx_n_s_test;
static PyObject *__pyx_kp_u_unknown_dtype_code_in_numpy_pxd;
static PyObject *__pyx_n_s_zeros;
-static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self, PyArrayObject *__pyx_v_a, PyArrayObject *__pyx_v_b, PyArrayObject *__pyx_v_M); /* proto */
+static PyObject *__pyx_pf_2ot_2lp_3emd_emd_c(CYTHON_UNUSED PyObject *__pyx_self, PyArrayObject *__pyx_v_a, PyArrayObject *__pyx_v_b, PyArrayObject *__pyx_v_M); /* proto */
static int __pyx_pf_5numpy_7ndarray___getbuffer__(PyArrayObject *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags); /* proto */
static void __pyx_pf_5numpy_7ndarray_2__releasebuffer__(PyArrayObject *__pyx_v_self, Py_buffer *__pyx_v_info); /* proto */
static PyObject *__pyx_tuple_;
@@ -1219,19 +1219,19 @@ static PyObject *__pyx_tuple__6;
static PyObject *__pyx_tuple__7;
static PyObject *__pyx_codeobj__8;
-/* "ot/emd/emd.pyx":21
+/* "ot/lp/emd.pyx":21
* @cython.boundscheck(False)
* @cython.wraparound(False)
- * def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M): # <<<<<<<<<<<<<<
+ * def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M): # <<<<<<<<<<<<<<
* """
* Solves the Earth Movers distance problem and returns the optimal transport matrix
*/
/* Python wrapper */
-static PyObject *__pyx_pw_2ot_3emd_3emd_1emd(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/
-static char __pyx_doc_2ot_3emd_3emd_emd[] = "\n Solves the Earth Movers distance problem and returns the optimal transport matrix\n \n gamm=emd(a,b,M)\n \n .. math::\n \\gamma = arg\\min_\\gamma <\\gamma,M>_F \n \n s.t. \\gamma 1 = a\n \n \\gamma^T 1= b \n \n \\gamma\\geq 0\n where :\n \n - M is the metric cost matrix\n - a and b are the sample weights\n \n Parameters\n ----------\n a : (ns,) ndarray\n samples in the source domain (uniform waigth if empty)\n b : (nt,) ndarray\n samples in the target domain (uniform waigth if empty)\n M : (ns,nt) ndarray\n loss matrix \n \n \n Returns\n -------\n gamma: (ns x nt) ndarray\n Optimal transportation matrix for the given parameters\n \n ";
-static PyMethodDef __pyx_mdef_2ot_3emd_3emd_1emd = {"emd", (PyCFunction)__pyx_pw_2ot_3emd_3emd_1emd, METH_VARARGS|METH_KEYWORDS, __pyx_doc_2ot_3emd_3emd_emd};
-static PyObject *__pyx_pw_2ot_3emd_3emd_1emd(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) {
+static PyObject *__pyx_pw_2ot_2lp_3emd_1emd_c(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/
+static char __pyx_doc_2ot_2lp_3emd_emd_c[] = "\n Solves the Earth Movers distance problem and returns the optimal transport matrix\n \n gamm=emd(a,b,M)\n \n .. math::\n \\gamma = arg\\min_\\gamma <\\gamma,M>_F \n \n s.t. \\gamma 1 = a\n \n \\gamma^T 1= b \n \n \\gamma\\geq 0\n where :\n \n - M is the metric cost matrix\n - a and b are the sample weights\n \n Parameters\n ----------\n a : (ns,) ndarray\n source histogram \n b : (nt,) ndarray\n target histogram\n M : (ns,nt) ndarray\n loss matrix \n \n \n Returns\n -------\n gamma: (ns x nt) ndarray\n Optimal transportation matrix for the given parameters\n \n ";
+static PyMethodDef __pyx_mdef_2ot_2lp_3emd_1emd_c = {"emd_c", (PyCFunction)__pyx_pw_2ot_2lp_3emd_1emd_c, METH_VARARGS|METH_KEYWORDS, __pyx_doc_2ot_2lp_3emd_emd_c};
+static PyObject *__pyx_pw_2ot_2lp_3emd_1emd_c(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) {
PyArrayObject *__pyx_v_a = 0;
PyArrayObject *__pyx_v_b = 0;
PyArrayObject *__pyx_v_M = 0;
@@ -1240,7 +1240,7 @@ static PyObject *__pyx_pw_2ot_3emd_3emd_1emd(PyObject *__pyx_self, PyObject *__p
int __pyx_clineno = 0;
PyObject *__pyx_r = 0;
__Pyx_RefNannyDeclarations
- __Pyx_RefNannySetupContext("emd (wrapper)", 0);
+ __Pyx_RefNannySetupContext("emd_c (wrapper)", 0);
{
static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_a,&__pyx_n_s_b,&__pyx_n_s_M,0};
PyObject* values[3] = {0,0,0};
@@ -1262,16 +1262,16 @@ static PyObject *__pyx_pw_2ot_3emd_3emd_1emd(PyObject *__pyx_self, PyObject *__p
case 1:
if (likely((values[1] = PyDict_GetItem(__pyx_kwds, __pyx_n_s_b)) != 0)) kw_args--;
else {
- __Pyx_RaiseArgtupleInvalid("emd", 1, 3, 3, 1); {__pyx_filename = __pyx_f[0]; __pyx_lineno = 21; __pyx_clineno = __LINE__; goto __pyx_L3_error;}
+ __Pyx_RaiseArgtupleInvalid("emd_c", 1, 3, 3, 1); {__pyx_filename = __pyx_f[0]; __pyx_lineno = 21; __pyx_clineno = __LINE__; goto __pyx_L3_error;}
}
case 2:
if (likely((values[2] = PyDict_GetItem(__pyx_kwds, __pyx_n_s_M)) != 0)) kw_args--;
else {
- __Pyx_RaiseArgtupleInvalid("emd", 1, 3, 3, 2); {__pyx_filename = __pyx_f[0]; __pyx_lineno = 21; __pyx_clineno = __LINE__; goto __pyx_L3_error;}
+ __Pyx_RaiseArgtupleInvalid("emd_c", 1, 3, 3, 2); {__pyx_filename = __pyx_f[0]; __pyx_lineno = 21; __pyx_clineno = __LINE__; goto __pyx_L3_error;}
}
}
if (unlikely(kw_args > 0)) {
- if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "emd") < 0)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 21; __pyx_clineno = __LINE__; goto __pyx_L3_error;}
+ if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "emd_c") < 0)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 21; __pyx_clineno = __LINE__; goto __pyx_L3_error;}
}
} else if (PyTuple_GET_SIZE(__pyx_args) != 3) {
goto __pyx_L5_argtuple_error;
@@ -1286,16 +1286,16 @@ static PyObject *__pyx_pw_2ot_3emd_3emd_1emd(PyObject *__pyx_self, PyObject *__p
}
goto __pyx_L4_argument_unpacking_done;
__pyx_L5_argtuple_error:;
- __Pyx_RaiseArgtupleInvalid("emd", 1, 3, 3, PyTuple_GET_SIZE(__pyx_args)); {__pyx_filename = __pyx_f[0]; __pyx_lineno = 21; __pyx_clineno = __LINE__; goto __pyx_L3_error;}
+ __Pyx_RaiseArgtupleInvalid("emd_c", 1, 3, 3, PyTuple_GET_SIZE(__pyx_args)); {__pyx_filename = __pyx_f[0]; __pyx_lineno = 21; __pyx_clineno = __LINE__; goto __pyx_L3_error;}
__pyx_L3_error:;
- __Pyx_AddTraceback("ot.emd.emd.emd", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ __Pyx_AddTraceback("ot.lp.emd.emd_c", __pyx_clineno, __pyx_lineno, __pyx_filename);
__Pyx_RefNannyFinishContext();
return NULL;
__pyx_L4_argument_unpacking_done:;
if (unlikely(!__Pyx_ArgTypeTest(((PyObject *)__pyx_v_a), __pyx_ptype_5numpy_ndarray, 1, "a", 0))) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 21; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
if (unlikely(!__Pyx_ArgTypeTest(((PyObject *)__pyx_v_b), __pyx_ptype_5numpy_ndarray, 1, "b", 0))) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 21; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
if (unlikely(!__Pyx_ArgTypeTest(((PyObject *)__pyx_v_M), __pyx_ptype_5numpy_ndarray, 1, "M", 0))) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 21; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
- __pyx_r = __pyx_pf_2ot_3emd_3emd_emd(__pyx_self, __pyx_v_a, __pyx_v_b, __pyx_v_M);
+ __pyx_r = __pyx_pf_2ot_2lp_3emd_emd_c(__pyx_self, __pyx_v_a, __pyx_v_b, __pyx_v_M);
/* function exit code */
goto __pyx_L0;
@@ -1306,7 +1306,7 @@ static PyObject *__pyx_pw_2ot_3emd_3emd_1emd(PyObject *__pyx_self, PyObject *__p
return __pyx_r;
}
-static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self, PyArrayObject *__pyx_v_a, PyArrayObject *__pyx_v_b, PyArrayObject *__pyx_v_M) {
+static PyObject *__pyx_pf_2ot_2lp_3emd_emd_c(CYTHON_UNUSED PyObject *__pyx_self, PyArrayObject *__pyx_v_a, PyArrayObject *__pyx_v_b, PyArrayObject *__pyx_v_M) {
int __pyx_v_n1;
int __pyx_v_n2;
float __pyx_v_cost;
@@ -1338,7 +1338,7 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self,
int __pyx_lineno = 0;
const char *__pyx_filename = NULL;
int __pyx_clineno = 0;
- __Pyx_RefNannySetupContext("emd", 0);
+ __Pyx_RefNannySetupContext("emd_c", 0);
__Pyx_INCREF((PyObject *)__pyx_v_a);
__Pyx_INCREF((PyObject *)__pyx_v_b);
__pyx_pybuffer_G.pybuffer.buf = NULL;
@@ -1373,7 +1373,7 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self,
}
__pyx_pybuffernd_M.diminfo[0].strides = __pyx_pybuffernd_M.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_M.diminfo[0].shape = __pyx_pybuffernd_M.rcbuffer->pybuffer.shape[0]; __pyx_pybuffernd_M.diminfo[1].strides = __pyx_pybuffernd_M.rcbuffer->pybuffer.strides[1]; __pyx_pybuffernd_M.diminfo[1].shape = __pyx_pybuffernd_M.rcbuffer->pybuffer.shape[1];
- /* "ot/emd/emd.pyx":56
+ /* "ot/lp/emd.pyx":56
*
* """
* cdef int n1= M.shape[0] # <<<<<<<<<<<<<<
@@ -1382,7 +1382,7 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self,
*/
__pyx_v_n1 = (__pyx_v_M->dimensions[0]);
- /* "ot/emd/emd.pyx":57
+ /* "ot/lp/emd.pyx":57
* """
* cdef int n1= M.shape[0]
* cdef int n2= M.shape[1] # <<<<<<<<<<<<<<
@@ -1391,7 +1391,7 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self,
*/
__pyx_v_n2 = (__pyx_v_M->dimensions[1]);
- /* "ot/emd/emd.pyx":59
+ /* "ot/lp/emd.pyx":59
* cdef int n2= M.shape[1]
*
* cdef float cost=0 # <<<<<<<<<<<<<<
@@ -1400,7 +1400,7 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self,
*/
__pyx_v_cost = 0.0;
- /* "ot/emd/emd.pyx":60
+ /* "ot/lp/emd.pyx":60
*
* cdef float cost=0
* cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2]) # <<<<<<<<<<<<<<
@@ -1464,7 +1464,7 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self,
__pyx_v_G = ((PyArrayObject *)__pyx_t_1);
__pyx_t_1 = 0;
- /* "ot/emd/emd.pyx":62
+ /* "ot/lp/emd.pyx":62
* cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2])
*
* if not len(a): # <<<<<<<<<<<<<<
@@ -1475,7 +1475,7 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self,
__pyx_t_8 = ((!(__pyx_t_7 != 0)) != 0);
if (__pyx_t_8) {
- /* "ot/emd/emd.pyx":63
+ /* "ot/lp/emd.pyx":63
*
* if not len(a):
* a=np.ones((n1,))/n1 # <<<<<<<<<<<<<<
@@ -1548,7 +1548,7 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self,
__Pyx_DECREF_SET(__pyx_v_a, ((PyArrayObject *)__pyx_t_4));
__pyx_t_4 = 0;
- /* "ot/emd/emd.pyx":62
+ /* "ot/lp/emd.pyx":62
* cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2])
*
* if not len(a): # <<<<<<<<<<<<<<
@@ -1557,7 +1557,7 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self,
*/
}
- /* "ot/emd/emd.pyx":65
+ /* "ot/lp/emd.pyx":65
* a=np.ones((n1,))/n1
*
* if not len(b): # <<<<<<<<<<<<<<
@@ -1568,7 +1568,7 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self,
__pyx_t_8 = ((!(__pyx_t_7 != 0)) != 0);
if (__pyx_t_8) {
- /* "ot/emd/emd.pyx":66
+ /* "ot/lp/emd.pyx":66
*
* if not len(b):
* b=np.ones((n2,))/n2 # <<<<<<<<<<<<<<
@@ -1641,7 +1641,7 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self,
__Pyx_DECREF_SET(__pyx_v_b, ((PyArrayObject *)__pyx_t_3));
__pyx_t_3 = 0;
- /* "ot/emd/emd.pyx":65
+ /* "ot/lp/emd.pyx":65
* a=np.ones((n1,))/n1
*
* if not len(b): # <<<<<<<<<<<<<<
@@ -1650,7 +1650,7 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self,
*/
}
- /* "ot/emd/emd.pyx":69
+ /* "ot/lp/emd.pyx":69
*
* # calling the function
* EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost) # <<<<<<<<<<<<<<
@@ -1659,7 +1659,7 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self,
*/
EMD_wrap(__pyx_v_n1, __pyx_v_n2, ((double *)__pyx_v_a->data), ((double *)__pyx_v_b->data), ((double *)__pyx_v_M->data), ((double *)__pyx_v_G->data), ((double *)(&__pyx_v_cost)));
- /* "ot/emd/emd.pyx":71
+ /* "ot/lp/emd.pyx":71
* EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost)
*
* return G # <<<<<<<<<<<<<<
@@ -1669,10 +1669,10 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self,
__pyx_r = ((PyObject *)__pyx_v_G);
goto __pyx_L0;
- /* "ot/emd/emd.pyx":21
+ /* "ot/lp/emd.pyx":21
* @cython.boundscheck(False)
* @cython.wraparound(False)
- * def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M): # <<<<<<<<<<<<<<
+ * def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M): # <<<<<<<<<<<<<<
* """
* Solves the Earth Movers distance problem and returns the optimal transport matrix
*/
@@ -1691,7 +1691,7 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self,
__Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_a.rcbuffer->pybuffer);
__Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_b.rcbuffer->pybuffer);
__Pyx_ErrRestore(__pyx_type, __pyx_value, __pyx_tb);}
- __Pyx_AddTraceback("ot.emd.emd.emd", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ __Pyx_AddTraceback("ot.lp.emd.emd_c", __pyx_clineno, __pyx_lineno, __pyx_filename);
__pyx_r = NULL;
goto __pyx_L2;
__pyx_L0:;
@@ -3884,8 +3884,8 @@ static __Pyx_StringTabEntry __pyx_string_tab[] = {
{&__pyx_n_s_a, __pyx_k_a, sizeof(__pyx_k_a), 0, 0, 1, 1},
{&__pyx_n_s_b, __pyx_k_b, sizeof(__pyx_k_b), 0, 0, 1, 1},
{&__pyx_n_s_cost, __pyx_k_cost, sizeof(__pyx_k_cost), 0, 0, 1, 1},
- {&__pyx_n_s_emd, __pyx_k_emd, sizeof(__pyx_k_emd), 0, 0, 1, 1},
- {&__pyx_kp_s_home_rflamary_PYTHON_POT_ot_emd, __pyx_k_home_rflamary_PYTHON_POT_ot_emd, sizeof(__pyx_k_home_rflamary_PYTHON_POT_ot_emd), 0, 0, 1, 0},
+ {&__pyx_n_s_emd_c, __pyx_k_emd_c, sizeof(__pyx_k_emd_c), 0, 0, 1, 1},
+ {&__pyx_kp_s_home_rflamary_PYTHON_POT_ot_lp, __pyx_k_home_rflamary_PYTHON_POT_ot_lp, sizeof(__pyx_k_home_rflamary_PYTHON_POT_ot_lp), 0, 0, 1, 0},
{&__pyx_n_s_import, __pyx_k_import, sizeof(__pyx_k_import), 0, 0, 1, 1},
{&__pyx_n_s_main, __pyx_k_main, sizeof(__pyx_k_main), 0, 0, 1, 1},
{&__pyx_n_s_n1, __pyx_k_n1, sizeof(__pyx_k_n1), 0, 0, 1, 1},
@@ -3895,7 +3895,7 @@ static __Pyx_StringTabEntry __pyx_string_tab[] = {
{&__pyx_n_s_np, __pyx_k_np, sizeof(__pyx_k_np), 0, 0, 1, 1},
{&__pyx_n_s_numpy, __pyx_k_numpy, sizeof(__pyx_k_numpy), 0, 0, 1, 1},
{&__pyx_n_s_ones, __pyx_k_ones, sizeof(__pyx_k_ones), 0, 0, 1, 1},
- {&__pyx_n_s_ot_emd_emd, __pyx_k_ot_emd_emd, sizeof(__pyx_k_ot_emd_emd), 0, 0, 1, 1},
+ {&__pyx_n_s_ot_lp_emd, __pyx_k_ot_lp_emd, sizeof(__pyx_k_ot_lp_emd), 0, 0, 1, 1},
{&__pyx_n_s_range, __pyx_k_range, sizeof(__pyx_k_range), 0, 0, 1, 1},
{&__pyx_n_s_test, __pyx_k_test, sizeof(__pyx_k_test), 0, 0, 1, 1},
{&__pyx_kp_u_unknown_dtype_code_in_numpy_pxd, __pyx_k_unknown_dtype_code_in_numpy_pxd, sizeof(__pyx_k_unknown_dtype_code_in_numpy_pxd), 0, 1, 0, 0},
@@ -3981,17 +3981,17 @@ static int __Pyx_InitCachedConstants(void) {
__Pyx_GOTREF(__pyx_tuple__6);
__Pyx_GIVEREF(__pyx_tuple__6);
- /* "ot/emd/emd.pyx":21
+ /* "ot/lp/emd.pyx":21
* @cython.boundscheck(False)
* @cython.wraparound(False)
- * def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M): # <<<<<<<<<<<<<<
+ * def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M): # <<<<<<<<<<<<<<
* """
* Solves the Earth Movers distance problem and returns the optimal transport matrix
*/
__pyx_tuple__7 = PyTuple_Pack(7, __pyx_n_s_a, __pyx_n_s_b, __pyx_n_s_M, __pyx_n_s_n1, __pyx_n_s_n2, __pyx_n_s_cost, __pyx_n_s_G); if (unlikely(!__pyx_tuple__7)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 21; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
__Pyx_GOTREF(__pyx_tuple__7);
__Pyx_GIVEREF(__pyx_tuple__7);
- __pyx_codeobj__8 = (PyObject*)__Pyx_PyCode_New(3, 0, 7, 0, 0, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__7, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_home_rflamary_PYTHON_POT_ot_emd, __pyx_n_s_emd, 21, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__8)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 21; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __pyx_codeobj__8 = (PyObject*)__Pyx_PyCode_New(3, 0, 7, 0, 0, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__7, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_home_rflamary_PYTHON_POT_ot_lp, __pyx_n_s_emd_c, 21, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__8)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 21; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
__Pyx_RefNannyFinishContext();
return 0;
__pyx_L1_error:;
@@ -4073,14 +4073,14 @@ PyMODINIT_FUNC PyInit_emd(void)
#if PY_MAJOR_VERSION < 3 && (__PYX_DEFAULT_STRING_ENCODING_IS_ASCII || __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT)
if (__Pyx_init_sys_getdefaultencoding_params() < 0) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 1; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
#endif
- if (__pyx_module_is_main_ot__emd__emd) {
+ if (__pyx_module_is_main_ot__lp__emd) {
if (PyObject_SetAttrString(__pyx_m, "__name__", __pyx_n_s_main) < 0) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 1; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
}
#if PY_MAJOR_VERSION >= 3
{
PyObject *modules = PyImport_GetModuleDict(); if (unlikely(!modules)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 1; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
- if (!PyDict_GetItemString(modules, "ot.emd.emd")) {
- if (unlikely(PyDict_SetItemString(modules, "ot.emd.emd", __pyx_m) < 0)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 1; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ if (!PyDict_GetItemString(modules, "ot.lp.emd")) {
+ if (unlikely(PyDict_SetItemString(modules, "ot.lp.emd", __pyx_m) < 0)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 1; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
}
}
#endif
@@ -4112,7 +4112,7 @@ PyMODINIT_FUNC PyInit_emd(void)
if (__Pyx_patch_abc() < 0) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 1; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
#endif
- /* "ot/emd/emd.pyx":7
+ /* "ot/lp/emd.pyx":7
* @author: rflamary
* """
* import numpy as np # <<<<<<<<<<<<<<
@@ -4124,19 +4124,19 @@ PyMODINIT_FUNC PyInit_emd(void)
if (PyDict_SetItem(__pyx_d, __pyx_n_s_np, __pyx_t_1) < 0) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 7; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
- /* "ot/emd/emd.pyx":21
+ /* "ot/lp/emd.pyx":21
* @cython.boundscheck(False)
* @cython.wraparound(False)
- * def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M): # <<<<<<<<<<<<<<
+ * def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M): # <<<<<<<<<<<<<<
* """
* Solves the Earth Movers distance problem and returns the optimal transport matrix
*/
- __pyx_t_1 = PyCFunction_NewEx(&__pyx_mdef_2ot_3emd_3emd_1emd, NULL, __pyx_n_s_ot_emd_emd); if (unlikely(!__pyx_t_1)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 21; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __pyx_t_1 = PyCFunction_NewEx(&__pyx_mdef_2ot_2lp_3emd_1emd_c, NULL, __pyx_n_s_ot_lp_emd); if (unlikely(!__pyx_t_1)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 21; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
__Pyx_GOTREF(__pyx_t_1);
- if (PyDict_SetItem(__pyx_d, __pyx_n_s_emd, __pyx_t_1) < 0) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 21; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ if (PyDict_SetItem(__pyx_d, __pyx_n_s_emd_c, __pyx_t_1) < 0) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 21; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
- /* "ot/emd/emd.pyx":1
+ /* "ot/lp/emd.pyx":1
* # -*- coding: utf-8 -*- # <<<<<<<<<<<<<<
* """
* Created on Thu Sep 11 08:42:08 2014
@@ -4161,11 +4161,11 @@ PyMODINIT_FUNC PyInit_emd(void)
__Pyx_XDECREF(__pyx_t_1);
if (__pyx_m) {
if (__pyx_d) {
- __Pyx_AddTraceback("init ot.emd.emd", __pyx_clineno, __pyx_lineno, __pyx_filename);
+ __Pyx_AddTraceback("init ot.lp.emd", __pyx_clineno, __pyx_lineno, __pyx_filename);
}
Py_DECREF(__pyx_m); __pyx_m = 0;
} else if (!PyErr_Occurred()) {
- PyErr_SetString(PyExc_ImportError, "init ot.emd.emd");
+ PyErr_SetString(PyExc_ImportError, "init ot.lp.emd");
}
__pyx_L0:;
__Pyx_RefNannyFinishContext();
diff --git a/ot/lp/emd.pyx b/ot/lp/emd.pyx
index 753b195..de2d4a9 100644
--- a/ot/lp/emd.pyx
+++ b/ot/lp/emd.pyx
@@ -18,7 +18,7 @@ cdef extern from "EMD.h":
@cython.boundscheck(False)
@cython.wraparound(False)
-def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M):
+def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M):
"""
Solves the Earth Movers distance problem and returns the optimal transport matrix
@@ -39,11 +39,11 @@ def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode=
Parameters
----------
- a : (ns,) ndarray
- samples in the source domain (uniform waigth if empty)
- b : (nt,) ndarray
- samples in the target domain (uniform waigth if empty)
- M : (ns,nt) ndarray
+ a : (ns,) ndarray, float64
+ source histogram
+ b : (nt,) ndarray, float64
+ target histogram
+ M : (ns,nt) ndarray, float64
loss matrix