From 872e6db7c0d110069b450cbe7efcc186c4871428 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Fri, 21 Oct 2016 11:19:46 +0200 Subject: demo with sinkhorn --- ot/emd/emd.cpp | 267 ++++++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 238 insertions(+), 29 deletions(-) (limited to 'ot/emd/emd.cpp') diff --git a/ot/emd/emd.cpp b/ot/emd/emd.cpp index 27dc2de..bd40d9c 100644 --- a/ot/emd/emd.cpp +++ b/ot/emd/emd.cpp @@ -896,6 +896,8 @@ static CYTHON_INLINE PyObject* __Pyx_PyObject_CallOneArg(PyObject *func, PyObjec static CYTHON_INLINE int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type); +static void __Pyx_RaiseBufferFallbackError(void); + static CYTHON_INLINE void __Pyx_ErrRestore(PyObject *type, PyObject *value, PyObject *tb); static CYTHON_INLINE void __Pyx_ErrFetch(PyObject **type, PyObject **value, PyObject **tb); @@ -1162,6 +1164,7 @@ static char __pyx_k_np[] = "np"; static char __pyx_k_emd[] = "emd"; static char __pyx_k_cost[] = "cost"; static char __pyx_k_main[] = "__main__"; +static char __pyx_k_ones[] = "ones"; static char __pyx_k_test[] = "__test__"; static char __pyx_k_numpy[] = "numpy"; static char __pyx_k_range[] = "range"; @@ -1198,6 +1201,7 @@ static PyObject *__pyx_kp_u_ndarray_is_not_C_contiguous; static PyObject *__pyx_kp_u_ndarray_is_not_Fortran_contiguou; static PyObject *__pyx_n_s_np; static PyObject *__pyx_n_s_numpy; +static PyObject *__pyx_n_s_ones; static PyObject *__pyx_n_s_ot_emd_emd; static PyObject *__pyx_n_s_range; static PyObject *__pyx_n_s_test; @@ -1219,13 +1223,14 @@ static PyObject *__pyx_codeobj__8; * @cython.boundscheck(False) * @cython.wraparound(False) * def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M): # <<<<<<<<<<<<<< - * cdef int n1= M.shape[0] - * cdef int n2= M.shape[1] + * """ + * Solves the Earth Movers distance problem and returns the optimal transport matrix */ /* Python wrapper */ static PyObject *__pyx_pw_2ot_3emd_3emd_1emd(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static PyMethodDef __pyx_mdef_2ot_3emd_3emd_1emd = {"emd", (PyCFunction)__pyx_pw_2ot_3emd_3emd_1emd, METH_VARARGS|METH_KEYWORDS, 0}; +static char __pyx_doc_2ot_3emd_3emd_emd[] = "\n Solves the Earth Movers distance problem and returns the optimal transport matrix\n \n .. math::\n \\gamma = arg\\min_\\gamma <\\gamma,M>_F \n \n s.t. \\gamma 1 = a\n \n \\gamma^T 1= b \n \n \\gamma\\geq 0\n where :\n \n - M is the metric cost matrix\n - a and b are the sample weights\n \n Parameters\n ----------\n a : (ns,) ndarray\n samples in the source domain\n b : (nt,) ndarray\n samples in the target domain\n M : (ns,nt) ndarray\n loss matrix \n \n \n Returns\n -------\n gamma: (ns x nt) ndarray\n Optimal transportation matrix for the given parameters\n \n "; +static PyMethodDef __pyx_mdef_2ot_3emd_3emd_1emd = {"emd", (PyCFunction)__pyx_pw_2ot_3emd_3emd_1emd, METH_VARARGS|METH_KEYWORDS, __pyx_doc_2ot_3emd_3emd_emd}; static PyObject *__pyx_pw_2ot_3emd_3emd_1emd(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { PyArrayObject *__pyx_v_a = 0; PyArrayObject *__pyx_v_b = 0; @@ -1322,10 +1327,20 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_t_4 = NULL; PyObject *__pyx_t_5 = NULL; PyArrayObject *__pyx_t_6 = NULL; + Py_ssize_t __pyx_t_7; + int __pyx_t_8; + PyArrayObject *__pyx_t_9 = NULL; + int __pyx_t_10; + PyObject *__pyx_t_11 = NULL; + PyObject *__pyx_t_12 = NULL; + PyObject *__pyx_t_13 = NULL; + PyArrayObject *__pyx_t_14 = NULL; int __pyx_lineno = 0; const char *__pyx_filename = NULL; int __pyx_clineno = 0; __Pyx_RefNannySetupContext("emd", 0); + __Pyx_INCREF((PyObject *)__pyx_v_a); + __Pyx_INCREF((PyObject *)__pyx_v_b); __pyx_pybuffer_G.pybuffer.buf = NULL; __pyx_pybuffer_G.refcount = 0; __pyx_pybuffernd_G.data = NULL; @@ -1358,17 +1373,17 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self, } __pyx_pybuffernd_M.diminfo[0].strides = __pyx_pybuffernd_M.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_M.diminfo[0].shape = __pyx_pybuffernd_M.rcbuffer->pybuffer.shape[0]; __pyx_pybuffernd_M.diminfo[1].strides = __pyx_pybuffernd_M.rcbuffer->pybuffer.strides[1]; __pyx_pybuffernd_M.diminfo[1].shape = __pyx_pybuffernd_M.rcbuffer->pybuffer.shape[1]; - /* "ot/emd/emd.pyx":22 - * @cython.wraparound(False) - * def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M): + /* "ot/emd/emd.pyx":54 + * + * """ * cdef int n1= M.shape[0] # <<<<<<<<<<<<<< * cdef int n2= M.shape[1] * */ __pyx_v_n1 = (__pyx_v_M->dimensions[0]); - /* "ot/emd/emd.pyx":23 - * def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M): + /* "ot/emd/emd.pyx":55 + * """ * cdef int n1= M.shape[0] * cdef int n2= M.shape[1] # <<<<<<<<<<<<<< * @@ -1376,7 +1391,7 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self, */ __pyx_v_n2 = (__pyx_v_M->dimensions[1]); - /* "ot/emd/emd.pyx":25 + /* "ot/emd/emd.pyx":57 * cdef int n2= M.shape[1] * * cdef float cost=0 # <<<<<<<<<<<<<< @@ -1385,23 +1400,23 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self, */ __pyx_v_cost = 0.0; - /* "ot/emd/emd.pyx":26 + /* "ot/emd/emd.pyx":58 * * cdef float cost=0 * cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2]) # <<<<<<<<<<<<<< * - * # calling the function + * if not len(a): */ - __pyx_t_2 = __Pyx_GetModuleGlobalName(__pyx_n_s_np); if (unlikely(!__pyx_t_2)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 26; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __pyx_t_2 = __Pyx_GetModuleGlobalName(__pyx_n_s_np); if (unlikely(!__pyx_t_2)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 58; __pyx_clineno = __LINE__; goto __pyx_L1_error;} __Pyx_GOTREF(__pyx_t_2); - __pyx_t_3 = __Pyx_PyObject_GetAttrStr(__pyx_t_2, __pyx_n_s_zeros); if (unlikely(!__pyx_t_3)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 26; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __pyx_t_3 = __Pyx_PyObject_GetAttrStr(__pyx_t_2, __pyx_n_s_zeros); if (unlikely(!__pyx_t_3)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 58; __pyx_clineno = __LINE__; goto __pyx_L1_error;} __Pyx_GOTREF(__pyx_t_3); __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; - __pyx_t_2 = __Pyx_PyInt_From_int(__pyx_v_n1); if (unlikely(!__pyx_t_2)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 26; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __pyx_t_2 = __Pyx_PyInt_From_int(__pyx_v_n1); if (unlikely(!__pyx_t_2)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 58; __pyx_clineno = __LINE__; goto __pyx_L1_error;} __Pyx_GOTREF(__pyx_t_2); - __pyx_t_4 = __Pyx_PyInt_From_int(__pyx_v_n2); if (unlikely(!__pyx_t_4)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 26; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __pyx_t_4 = __Pyx_PyInt_From_int(__pyx_v_n2); if (unlikely(!__pyx_t_4)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 58; __pyx_clineno = __LINE__; goto __pyx_L1_error;} __Pyx_GOTREF(__pyx_t_4); - __pyx_t_5 = PyList_New(2); if (unlikely(!__pyx_t_5)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 26; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __pyx_t_5 = PyList_New(2); if (unlikely(!__pyx_t_5)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 58; __pyx_clineno = __LINE__; goto __pyx_L1_error;} __Pyx_GOTREF(__pyx_t_5); __Pyx_GIVEREF(__pyx_t_2); PyList_SET_ITEM(__pyx_t_5, 0, __pyx_t_2); @@ -1420,28 +1435,28 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self, } } if (!__pyx_t_4) { - __pyx_t_1 = __Pyx_PyObject_CallOneArg(__pyx_t_3, __pyx_t_5); if (unlikely(!__pyx_t_1)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 26; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __pyx_t_1 = __Pyx_PyObject_CallOneArg(__pyx_t_3, __pyx_t_5); if (unlikely(!__pyx_t_1)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 58; __pyx_clineno = __LINE__; goto __pyx_L1_error;} __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; __Pyx_GOTREF(__pyx_t_1); } else { - __pyx_t_2 = PyTuple_New(1+1); if (unlikely(!__pyx_t_2)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 26; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __pyx_t_2 = PyTuple_New(1+1); if (unlikely(!__pyx_t_2)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 58; __pyx_clineno = __LINE__; goto __pyx_L1_error;} __Pyx_GOTREF(__pyx_t_2); __Pyx_GIVEREF(__pyx_t_4); PyTuple_SET_ITEM(__pyx_t_2, 0, __pyx_t_4); __pyx_t_4 = NULL; __Pyx_GIVEREF(__pyx_t_5); PyTuple_SET_ITEM(__pyx_t_2, 0+1, __pyx_t_5); __pyx_t_5 = 0; - __pyx_t_1 = __Pyx_PyObject_Call(__pyx_t_3, __pyx_t_2, NULL); if (unlikely(!__pyx_t_1)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 26; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __pyx_t_1 = __Pyx_PyObject_Call(__pyx_t_3, __pyx_t_2, NULL); if (unlikely(!__pyx_t_1)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 58; __pyx_clineno = __LINE__; goto __pyx_L1_error;} __Pyx_GOTREF(__pyx_t_1); __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; } __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; - if (!(likely(((__pyx_t_1) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_1, __pyx_ptype_5numpy_ndarray))))) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 26; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + if (!(likely(((__pyx_t_1) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_1, __pyx_ptype_5numpy_ndarray))))) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 58; __pyx_clineno = __LINE__; goto __pyx_L1_error;} __pyx_t_6 = ((PyArrayObject *)__pyx_t_1); { __Pyx_BufFmt_StackElem __pyx_stack[1]; if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_G.rcbuffer->pybuffer, (PyObject*)__pyx_t_6, &__Pyx_TypeInfo_double, PyBUF_FORMAT| PyBUF_C_CONTIGUOUS, 2, 0, __pyx_stack) == -1)) { __pyx_v_G = ((PyArrayObject *)Py_None); __Pyx_INCREF(Py_None); __pyx_pybuffernd_G.rcbuffer->pybuffer.buf = NULL; - {__pyx_filename = __pyx_f[0]; __pyx_lineno = 26; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + {__pyx_filename = __pyx_f[0]; __pyx_lineno = 58; __pyx_clineno = __LINE__; goto __pyx_L1_error;} } else {__pyx_pybuffernd_G.diminfo[0].strides = __pyx_pybuffernd_G.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_G.diminfo[0].shape = __pyx_pybuffernd_G.rcbuffer->pybuffer.shape[0]; __pyx_pybuffernd_G.diminfo[1].strides = __pyx_pybuffernd_G.rcbuffer->pybuffer.strides[1]; __pyx_pybuffernd_G.diminfo[1].shape = __pyx_pybuffernd_G.rcbuffer->pybuffer.shape[1]; } } @@ -1449,7 +1464,193 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self, __pyx_v_G = ((PyArrayObject *)__pyx_t_1); __pyx_t_1 = 0; - /* "ot/emd/emd.pyx":29 + /* "ot/emd/emd.pyx":60 + * cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2]) + * + * if not len(a): # <<<<<<<<<<<<<< + * a=np.ones((n1,))/n1 + * + */ + __pyx_t_7 = PyObject_Length(((PyObject *)__pyx_v_a)); if (unlikely(__pyx_t_7 == -1)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 60; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __pyx_t_8 = ((!(__pyx_t_7 != 0)) != 0); + if (__pyx_t_8) { + + /* "ot/emd/emd.pyx":61 + * + * if not len(a): + * a=np.ones((n1,))/n1 # <<<<<<<<<<<<<< + * + * if not len(b): + */ + __pyx_t_3 = __Pyx_GetModuleGlobalName(__pyx_n_s_np); if (unlikely(!__pyx_t_3)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 61; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_t_3, __pyx_n_s_ones); if (unlikely(!__pyx_t_2)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 61; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __pyx_t_3 = __Pyx_PyInt_From_int(__pyx_v_n1); if (unlikely(!__pyx_t_3)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 61; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_5 = PyTuple_New(1); if (unlikely(!__pyx_t_5)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 61; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __Pyx_GOTREF(__pyx_t_5); + __Pyx_GIVEREF(__pyx_t_3); + PyTuple_SET_ITEM(__pyx_t_5, 0, __pyx_t_3); + __pyx_t_3 = 0; + __pyx_t_3 = NULL; + if (CYTHON_COMPILING_IN_CPYTHON && unlikely(PyMethod_Check(__pyx_t_2))) { + __pyx_t_3 = PyMethod_GET_SELF(__pyx_t_2); + if (likely(__pyx_t_3)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_2); + __Pyx_INCREF(__pyx_t_3); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_2, function); + } + } + if (!__pyx_t_3) { + __pyx_t_1 = __Pyx_PyObject_CallOneArg(__pyx_t_2, __pyx_t_5); if (unlikely(!__pyx_t_1)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 61; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_GOTREF(__pyx_t_1); + } else { + __pyx_t_4 = PyTuple_New(1+1); if (unlikely(!__pyx_t_4)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 61; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __Pyx_GOTREF(__pyx_t_4); + __Pyx_GIVEREF(__pyx_t_3); PyTuple_SET_ITEM(__pyx_t_4, 0, __pyx_t_3); __pyx_t_3 = NULL; + __Pyx_GIVEREF(__pyx_t_5); + PyTuple_SET_ITEM(__pyx_t_4, 0+1, __pyx_t_5); + __pyx_t_5 = 0; + __pyx_t_1 = __Pyx_PyObject_Call(__pyx_t_2, __pyx_t_4, NULL); if (unlikely(!__pyx_t_1)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 61; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + } + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_t_2 = __Pyx_PyInt_From_int(__pyx_v_n1); if (unlikely(!__pyx_t_2)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 61; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_4 = __Pyx_PyNumber_Divide(__pyx_t_1, __pyx_t_2); if (unlikely(!__pyx_t_4)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 61; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __Pyx_GOTREF(__pyx_t_4); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + if (!(likely(((__pyx_t_4) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_4, __pyx_ptype_5numpy_ndarray))))) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 61; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __pyx_t_9 = ((PyArrayObject *)__pyx_t_4); + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_a.rcbuffer->pybuffer); + __pyx_t_10 = __Pyx_GetBufferAndValidate(&__pyx_pybuffernd_a.rcbuffer->pybuffer, (PyObject*)__pyx_t_9, &__Pyx_TypeInfo_double, PyBUF_FORMAT| PyBUF_C_CONTIGUOUS, 1, 0, __pyx_stack); + if (unlikely(__pyx_t_10 < 0)) { + PyErr_Fetch(&__pyx_t_11, &__pyx_t_12, &__pyx_t_13); + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_a.rcbuffer->pybuffer, (PyObject*)__pyx_v_a, &__Pyx_TypeInfo_double, PyBUF_FORMAT| PyBUF_C_CONTIGUOUS, 1, 0, __pyx_stack) == -1)) { + Py_XDECREF(__pyx_t_11); Py_XDECREF(__pyx_t_12); Py_XDECREF(__pyx_t_13); + __Pyx_RaiseBufferFallbackError(); + } else { + PyErr_Restore(__pyx_t_11, __pyx_t_12, __pyx_t_13); + } + } + __pyx_pybuffernd_a.diminfo[0].strides = __pyx_pybuffernd_a.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_a.diminfo[0].shape = __pyx_pybuffernd_a.rcbuffer->pybuffer.shape[0]; + if (unlikely(__pyx_t_10 < 0)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 61; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + } + __pyx_t_9 = 0; + __Pyx_DECREF_SET(__pyx_v_a, ((PyArrayObject *)__pyx_t_4)); + __pyx_t_4 = 0; + + /* "ot/emd/emd.pyx":60 + * cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2]) + * + * if not len(a): # <<<<<<<<<<<<<< + * a=np.ones((n1,))/n1 + * + */ + } + + /* "ot/emd/emd.pyx":63 + * a=np.ones((n1,))/n1 + * + * if not len(b): # <<<<<<<<<<<<<< + * b=np.ones((n2,))/n2 + * + */ + __pyx_t_7 = PyObject_Length(((PyObject *)__pyx_v_b)); if (unlikely(__pyx_t_7 == -1)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 63; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __pyx_t_8 = ((!(__pyx_t_7 != 0)) != 0); + if (__pyx_t_8) { + + /* "ot/emd/emd.pyx":64 + * + * if not len(b): + * b=np.ones((n2,))/n2 # <<<<<<<<<<<<<< + * + * # calling the function + */ + __pyx_t_2 = __Pyx_GetModuleGlobalName(__pyx_n_s_np); if (unlikely(!__pyx_t_2)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 64; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_t_2, __pyx_n_s_ones); if (unlikely(!__pyx_t_1)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 64; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_t_2 = __Pyx_PyInt_From_int(__pyx_v_n2); if (unlikely(!__pyx_t_2)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 64; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_5 = PyTuple_New(1); if (unlikely(!__pyx_t_5)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 64; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __Pyx_GOTREF(__pyx_t_5); + __Pyx_GIVEREF(__pyx_t_2); + PyTuple_SET_ITEM(__pyx_t_5, 0, __pyx_t_2); + __pyx_t_2 = 0; + __pyx_t_2 = NULL; + if (CYTHON_COMPILING_IN_CPYTHON && unlikely(PyMethod_Check(__pyx_t_1))) { + __pyx_t_2 = PyMethod_GET_SELF(__pyx_t_1); + if (likely(__pyx_t_2)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_1); + __Pyx_INCREF(__pyx_t_2); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_1, function); + } + } + if (!__pyx_t_2) { + __pyx_t_4 = __Pyx_PyObject_CallOneArg(__pyx_t_1, __pyx_t_5); if (unlikely(!__pyx_t_4)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 64; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_GOTREF(__pyx_t_4); + } else { + __pyx_t_3 = PyTuple_New(1+1); if (unlikely(!__pyx_t_3)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 64; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __Pyx_GOTREF(__pyx_t_3); + __Pyx_GIVEREF(__pyx_t_2); PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_t_2); __pyx_t_2 = NULL; + __Pyx_GIVEREF(__pyx_t_5); + PyTuple_SET_ITEM(__pyx_t_3, 0+1, __pyx_t_5); + __pyx_t_5 = 0; + __pyx_t_4 = __Pyx_PyObject_Call(__pyx_t_1, __pyx_t_3, NULL); if (unlikely(!__pyx_t_4)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 64; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __Pyx_GOTREF(__pyx_t_4); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + } + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_t_1 = __Pyx_PyInt_From_int(__pyx_v_n2); if (unlikely(!__pyx_t_1)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 64; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_3 = __Pyx_PyNumber_Divide(__pyx_t_4, __pyx_t_1); if (unlikely(!__pyx_t_3)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 64; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + if (!(likely(((__pyx_t_3) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_3, __pyx_ptype_5numpy_ndarray))))) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 64; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __pyx_t_14 = ((PyArrayObject *)__pyx_t_3); + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_b.rcbuffer->pybuffer); + __pyx_t_10 = __Pyx_GetBufferAndValidate(&__pyx_pybuffernd_b.rcbuffer->pybuffer, (PyObject*)__pyx_t_14, &__Pyx_TypeInfo_double, PyBUF_FORMAT| PyBUF_C_CONTIGUOUS, 1, 0, __pyx_stack); + if (unlikely(__pyx_t_10 < 0)) { + PyErr_Fetch(&__pyx_t_13, &__pyx_t_12, &__pyx_t_11); + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_b.rcbuffer->pybuffer, (PyObject*)__pyx_v_b, &__Pyx_TypeInfo_double, PyBUF_FORMAT| PyBUF_C_CONTIGUOUS, 1, 0, __pyx_stack) == -1)) { + Py_XDECREF(__pyx_t_13); Py_XDECREF(__pyx_t_12); Py_XDECREF(__pyx_t_11); + __Pyx_RaiseBufferFallbackError(); + } else { + PyErr_Restore(__pyx_t_13, __pyx_t_12, __pyx_t_11); + } + } + __pyx_pybuffernd_b.diminfo[0].strides = __pyx_pybuffernd_b.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_b.diminfo[0].shape = __pyx_pybuffernd_b.rcbuffer->pybuffer.shape[0]; + if (unlikely(__pyx_t_10 < 0)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 64; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + } + __pyx_t_14 = 0; + __Pyx_DECREF_SET(__pyx_v_b, ((PyArrayObject *)__pyx_t_3)); + __pyx_t_3 = 0; + + /* "ot/emd/emd.pyx":63 + * a=np.ones((n1,))/n1 + * + * if not len(b): # <<<<<<<<<<<<<< + * b=np.ones((n2,))/n2 + * + */ + } + + /* "ot/emd/emd.pyx":67 * * # calling the function * EMD_wrap(n1,n2, a.data, b.data, M.data, G.data, &cost) # <<<<<<<<<<<<<< @@ -1458,7 +1659,7 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self, */ EMD_wrap(__pyx_v_n1, __pyx_v_n2, ((double *)__pyx_v_a->data), ((double *)__pyx_v_b->data), ((double *)__pyx_v_M->data), ((double *)__pyx_v_G->data), ((double *)(&__pyx_v_cost))); - /* "ot/emd/emd.pyx":31 + /* "ot/emd/emd.pyx":69 * EMD_wrap(n1,n2, a.data, b.data, M.data, G.data, &cost) * * return G # <<<<<<<<<<<<<< @@ -1472,8 +1673,8 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self, * @cython.boundscheck(False) * @cython.wraparound(False) * def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M): # <<<<<<<<<<<<<< - * cdef int n1= M.shape[0] - * cdef int n2= M.shape[1] + * """ + * Solves the Earth Movers distance problem and returns the optimal transport matrix */ /* function exit code */ @@ -1500,6 +1701,8 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self, __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_b.rcbuffer->pybuffer); __pyx_L2:; __Pyx_XDECREF((PyObject *)__pyx_v_G); + __Pyx_XDECREF((PyObject *)__pyx_v_a); + __Pyx_XDECREF((PyObject *)__pyx_v_b); __Pyx_XGIVEREF(__pyx_r); __Pyx_RefNannyFinishContext(); return __pyx_r; @@ -3691,6 +3894,7 @@ static __Pyx_StringTabEntry __pyx_string_tab[] = { {&__pyx_kp_u_ndarray_is_not_Fortran_contiguou, __pyx_k_ndarray_is_not_Fortran_contiguou, sizeof(__pyx_k_ndarray_is_not_Fortran_contiguou), 0, 1, 0, 0}, {&__pyx_n_s_np, __pyx_k_np, sizeof(__pyx_k_np), 0, 0, 1, 1}, {&__pyx_n_s_numpy, __pyx_k_numpy, sizeof(__pyx_k_numpy), 0, 0, 1, 1}, + {&__pyx_n_s_ones, __pyx_k_ones, sizeof(__pyx_k_ones), 0, 0, 1, 1}, {&__pyx_n_s_ot_emd_emd, __pyx_k_ot_emd_emd, sizeof(__pyx_k_ot_emd_emd), 0, 0, 1, 1}, {&__pyx_n_s_range, __pyx_k_range, sizeof(__pyx_k_range), 0, 0, 1, 1}, {&__pyx_n_s_test, __pyx_k_test, sizeof(__pyx_k_test), 0, 0, 1, 1}, @@ -3781,8 +3985,8 @@ static int __Pyx_InitCachedConstants(void) { * @cython.boundscheck(False) * @cython.wraparound(False) * def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M): # <<<<<<<<<<<<<< - * cdef int n1= M.shape[0] - * cdef int n2= M.shape[1] + * """ + * Solves the Earth Movers distance problem and returns the optimal transport matrix */ __pyx_tuple__7 = PyTuple_Pack(7, __pyx_n_s_a, __pyx_n_s_b, __pyx_n_s_M, __pyx_n_s_n1, __pyx_n_s_n2, __pyx_n_s_cost, __pyx_n_s_G); if (unlikely(!__pyx_tuple__7)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 21; __pyx_clineno = __LINE__; goto __pyx_L1_error;} __Pyx_GOTREF(__pyx_tuple__7); @@ -3924,8 +4128,8 @@ PyMODINIT_FUNC PyInit_emd(void) * @cython.boundscheck(False) * @cython.wraparound(False) * def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M): # <<<<<<<<<<<<<< - * cdef int n1= M.shape[0] - * cdef int n2= M.shape[1] + * """ + * Solves the Earth Movers distance problem and returns the optimal transport matrix */ __pyx_t_1 = PyCFunction_NewEx(&__pyx_mdef_2ot_3emd_3emd_1emd, NULL, __pyx_n_s_ot_emd_emd); if (unlikely(!__pyx_t_1)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 21; __pyx_clineno = __LINE__; goto __pyx_L1_error;} __Pyx_GOTREF(__pyx_t_1); @@ -4817,6 +5021,11 @@ static CYTHON_INLINE int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type) { return 0; } +static void __Pyx_RaiseBufferFallbackError(void) { + PyErr_SetString(PyExc_ValueError, + "Buffer acquisition failed on assignment; and then reacquiring the old buffer failed too!"); +} + static CYTHON_INLINE void __Pyx_ErrRestore(PyObject *type, PyObject *value, PyObject *tb) { #if CYTHON_COMPILING_IN_CPYTHON PyObject *tmp_type, *tmp_value, *tmp_tb; -- cgit v1.2.3