summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-21 11:19:46 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-21 11:19:46 +0200
commit872e6db7c0d110069b450cbe7efcc186c4871428 (patch)
treef499a0258cf69f47a211a54447af990ac0afb591 /ot
parent581c6de782dca279edd97778cc474e7597788c0f (diff)
demo with sinkhorn
Diffstat (limited to 'ot')
-rw-r--r--ot/__init__.py8
-rw-r--r--ot/bregman/sink.py (renamed from ot/bregman/sinkhorn.py)2
-rw-r--r--ot/emd/emd.cpp267
-rw-r--r--ot/emd/emd.pyx44
-rw-r--r--ot/utils.py13
5 files changed, 289 insertions, 45 deletions
diff --git a/ot/__init__.py b/ot/__init__.py
index beeae7f..14c6181 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -1,11 +1,13 @@
+# utils submodules
+import utils
+import datasets
-
+# Ot functions
from emd import emd
from bregman import sinkhorn
-import utils
-import datasets
+
from utils import dist,dots
diff --git a/ot/bregman/sinkhorn.py b/ot/bregman/sink.py
index 798ac97..8b97e1e 100644
--- a/ot/bregman/sinkhorn.py
+++ b/ot/bregman/sink.py
@@ -59,7 +59,7 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
#print reg
- K = np.exp(-reg*M)
+ K = np.exp(-M/reg)
#print np.min(K)
Kp = np.dot(np.diag(1/a),K)
diff --git a/ot/emd/emd.cpp b/ot/emd/emd.cpp
index 27dc2de..bd40d9c 100644
--- a/ot/emd/emd.cpp
+++ b/ot/emd/emd.cpp
@@ -896,6 +896,8 @@ static CYTHON_INLINE PyObject* __Pyx_PyObject_CallOneArg(PyObject *func, PyObjec
static CYTHON_INLINE int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type);
+static void __Pyx_RaiseBufferFallbackError(void);
+
static CYTHON_INLINE void __Pyx_ErrRestore(PyObject *type, PyObject *value, PyObject *tb);
static CYTHON_INLINE void __Pyx_ErrFetch(PyObject **type, PyObject **value, PyObject **tb);
@@ -1162,6 +1164,7 @@ static char __pyx_k_np[] = "np";
static char __pyx_k_emd[] = "emd";
static char __pyx_k_cost[] = "cost";
static char __pyx_k_main[] = "__main__";
+static char __pyx_k_ones[] = "ones";
static char __pyx_k_test[] = "__test__";
static char __pyx_k_numpy[] = "numpy";
static char __pyx_k_range[] = "range";
@@ -1198,6 +1201,7 @@ static PyObject *__pyx_kp_u_ndarray_is_not_C_contiguous;
static PyObject *__pyx_kp_u_ndarray_is_not_Fortran_contiguou;
static PyObject *__pyx_n_s_np;
static PyObject *__pyx_n_s_numpy;
+static PyObject *__pyx_n_s_ones;
static PyObject *__pyx_n_s_ot_emd_emd;
static PyObject *__pyx_n_s_range;
static PyObject *__pyx_n_s_test;
@@ -1219,13 +1223,14 @@ static PyObject *__pyx_codeobj__8;
* @cython.boundscheck(False)
* @cython.wraparound(False)
* def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M): # <<<<<<<<<<<<<<
- * cdef int n1= M.shape[0]
- * cdef int n2= M.shape[1]
+ * """
+ * Solves the Earth Movers distance problem and returns the optimal transport matrix
*/
/* Python wrapper */
static PyObject *__pyx_pw_2ot_3emd_3emd_1emd(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/
-static PyMethodDef __pyx_mdef_2ot_3emd_3emd_1emd = {"emd", (PyCFunction)__pyx_pw_2ot_3emd_3emd_1emd, METH_VARARGS|METH_KEYWORDS, 0};
+static char __pyx_doc_2ot_3emd_3emd_emd[] = "\n Solves the Earth Movers distance problem and returns the optimal transport matrix\n \n .. math::\n \\gamma = arg\\min_\\gamma <\\gamma,M>_F \n \n s.t. \\gamma 1 = a\n \n \\gamma^T 1= b \n \n \\gamma\\geq 0\n where :\n \n - M is the metric cost matrix\n - a and b are the sample weights\n \n Parameters\n ----------\n a : (ns,) ndarray\n samples in the source domain\n b : (nt,) ndarray\n samples in the target domain\n M : (ns,nt) ndarray\n loss matrix \n \n \n Returns\n -------\n gamma: (ns x nt) ndarray\n Optimal transportation matrix for the given parameters\n \n ";
+static PyMethodDef __pyx_mdef_2ot_3emd_3emd_1emd = {"emd", (PyCFunction)__pyx_pw_2ot_3emd_3emd_1emd, METH_VARARGS|METH_KEYWORDS, __pyx_doc_2ot_3emd_3emd_emd};
static PyObject *__pyx_pw_2ot_3emd_3emd_1emd(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) {
PyArrayObject *__pyx_v_a = 0;
PyArrayObject *__pyx_v_b = 0;
@@ -1322,10 +1327,20 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self,
PyObject *__pyx_t_4 = NULL;
PyObject *__pyx_t_5 = NULL;
PyArrayObject *__pyx_t_6 = NULL;
+ Py_ssize_t __pyx_t_7;
+ int __pyx_t_8;
+ PyArrayObject *__pyx_t_9 = NULL;
+ int __pyx_t_10;
+ PyObject *__pyx_t_11 = NULL;
+ PyObject *__pyx_t_12 = NULL;
+ PyObject *__pyx_t_13 = NULL;
+ PyArrayObject *__pyx_t_14 = NULL;
int __pyx_lineno = 0;
const char *__pyx_filename = NULL;
int __pyx_clineno = 0;
__Pyx_RefNannySetupContext("emd", 0);
+ __Pyx_INCREF((PyObject *)__pyx_v_a);
+ __Pyx_INCREF((PyObject *)__pyx_v_b);
__pyx_pybuffer_G.pybuffer.buf = NULL;
__pyx_pybuffer_G.refcount = 0;
__pyx_pybuffernd_G.data = NULL;
@@ -1358,17 +1373,17 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self,
}
__pyx_pybuffernd_M.diminfo[0].strides = __pyx_pybuffernd_M.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_M.diminfo[0].shape = __pyx_pybuffernd_M.rcbuffer->pybuffer.shape[0]; __pyx_pybuffernd_M.diminfo[1].strides = __pyx_pybuffernd_M.rcbuffer->pybuffer.strides[1]; __pyx_pybuffernd_M.diminfo[1].shape = __pyx_pybuffernd_M.rcbuffer->pybuffer.shape[1];
- /* "ot/emd/emd.pyx":22
- * @cython.wraparound(False)
- * def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M):
+ /* "ot/emd/emd.pyx":54
+ *
+ * """
* cdef int n1= M.shape[0] # <<<<<<<<<<<<<<
* cdef int n2= M.shape[1]
*
*/
__pyx_v_n1 = (__pyx_v_M->dimensions[0]);
- /* "ot/emd/emd.pyx":23
- * def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M):
+ /* "ot/emd/emd.pyx":55
+ * """
* cdef int n1= M.shape[0]
* cdef int n2= M.shape[1] # <<<<<<<<<<<<<<
*
@@ -1376,7 +1391,7 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self,
*/
__pyx_v_n2 = (__pyx_v_M->dimensions[1]);
- /* "ot/emd/emd.pyx":25
+ /* "ot/emd/emd.pyx":57
* cdef int n2= M.shape[1]
*
* cdef float cost=0 # <<<<<<<<<<<<<<
@@ -1385,23 +1400,23 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self,
*/
__pyx_v_cost = 0.0;
- /* "ot/emd/emd.pyx":26
+ /* "ot/emd/emd.pyx":58
*
* cdef float cost=0
* cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2]) # <<<<<<<<<<<<<<
*
- * # calling the function
+ * if not len(a):
*/
- __pyx_t_2 = __Pyx_GetModuleGlobalName(__pyx_n_s_np); if (unlikely(!__pyx_t_2)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 26; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __pyx_t_2 = __Pyx_GetModuleGlobalName(__pyx_n_s_np); if (unlikely(!__pyx_t_2)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 58; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
__Pyx_GOTREF(__pyx_t_2);
- __pyx_t_3 = __Pyx_PyObject_GetAttrStr(__pyx_t_2, __pyx_n_s_zeros); if (unlikely(!__pyx_t_3)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 26; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __pyx_t_3 = __Pyx_PyObject_GetAttrStr(__pyx_t_2, __pyx_n_s_zeros); if (unlikely(!__pyx_t_3)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 58; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
__Pyx_GOTREF(__pyx_t_3);
__Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
- __pyx_t_2 = __Pyx_PyInt_From_int(__pyx_v_n1); if (unlikely(!__pyx_t_2)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 26; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __pyx_t_2 = __Pyx_PyInt_From_int(__pyx_v_n1); if (unlikely(!__pyx_t_2)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 58; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
__Pyx_GOTREF(__pyx_t_2);
- __pyx_t_4 = __Pyx_PyInt_From_int(__pyx_v_n2); if (unlikely(!__pyx_t_4)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 26; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __pyx_t_4 = __Pyx_PyInt_From_int(__pyx_v_n2); if (unlikely(!__pyx_t_4)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 58; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
__Pyx_GOTREF(__pyx_t_4);
- __pyx_t_5 = PyList_New(2); if (unlikely(!__pyx_t_5)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 26; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __pyx_t_5 = PyList_New(2); if (unlikely(!__pyx_t_5)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 58; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
__Pyx_GOTREF(__pyx_t_5);
__Pyx_GIVEREF(__pyx_t_2);
PyList_SET_ITEM(__pyx_t_5, 0, __pyx_t_2);
@@ -1420,28 +1435,28 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self,
}
}
if (!__pyx_t_4) {
- __pyx_t_1 = __Pyx_PyObject_CallOneArg(__pyx_t_3, __pyx_t_5); if (unlikely(!__pyx_t_1)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 26; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __pyx_t_1 = __Pyx_PyObject_CallOneArg(__pyx_t_3, __pyx_t_5); if (unlikely(!__pyx_t_1)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 58; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
__Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0;
__Pyx_GOTREF(__pyx_t_1);
} else {
- __pyx_t_2 = PyTuple_New(1+1); if (unlikely(!__pyx_t_2)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 26; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __pyx_t_2 = PyTuple_New(1+1); if (unlikely(!__pyx_t_2)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 58; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
__Pyx_GOTREF(__pyx_t_2);
__Pyx_GIVEREF(__pyx_t_4); PyTuple_SET_ITEM(__pyx_t_2, 0, __pyx_t_4); __pyx_t_4 = NULL;
__Pyx_GIVEREF(__pyx_t_5);
PyTuple_SET_ITEM(__pyx_t_2, 0+1, __pyx_t_5);
__pyx_t_5 = 0;
- __pyx_t_1 = __Pyx_PyObject_Call(__pyx_t_3, __pyx_t_2, NULL); if (unlikely(!__pyx_t_1)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 26; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __pyx_t_1 = __Pyx_PyObject_Call(__pyx_t_3, __pyx_t_2, NULL); if (unlikely(!__pyx_t_1)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 58; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
__Pyx_GOTREF(__pyx_t_1);
__Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
}
__Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
- if (!(likely(((__pyx_t_1) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_1, __pyx_ptype_5numpy_ndarray))))) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 26; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ if (!(likely(((__pyx_t_1) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_1, __pyx_ptype_5numpy_ndarray))))) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 58; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
__pyx_t_6 = ((PyArrayObject *)__pyx_t_1);
{
__Pyx_BufFmt_StackElem __pyx_stack[1];
if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_G.rcbuffer->pybuffer, (PyObject*)__pyx_t_6, &__Pyx_TypeInfo_double, PyBUF_FORMAT| PyBUF_C_CONTIGUOUS, 2, 0, __pyx_stack) == -1)) {
__pyx_v_G = ((PyArrayObject *)Py_None); __Pyx_INCREF(Py_None); __pyx_pybuffernd_G.rcbuffer->pybuffer.buf = NULL;
- {__pyx_filename = __pyx_f[0]; __pyx_lineno = 26; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ {__pyx_filename = __pyx_f[0]; __pyx_lineno = 58; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
} else {__pyx_pybuffernd_G.diminfo[0].strides = __pyx_pybuffernd_G.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_G.diminfo[0].shape = __pyx_pybuffernd_G.rcbuffer->pybuffer.shape[0]; __pyx_pybuffernd_G.diminfo[1].strides = __pyx_pybuffernd_G.rcbuffer->pybuffer.strides[1]; __pyx_pybuffernd_G.diminfo[1].shape = __pyx_pybuffernd_G.rcbuffer->pybuffer.shape[1];
}
}
@@ -1449,7 +1464,193 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self,
__pyx_v_G = ((PyArrayObject *)__pyx_t_1);
__pyx_t_1 = 0;
- /* "ot/emd/emd.pyx":29
+ /* "ot/emd/emd.pyx":60
+ * cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2])
+ *
+ * if not len(a): # <<<<<<<<<<<<<<
+ * a=np.ones((n1,))/n1
+ *
+ */
+ __pyx_t_7 = PyObject_Length(((PyObject *)__pyx_v_a)); if (unlikely(__pyx_t_7 == -1)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 60; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __pyx_t_8 = ((!(__pyx_t_7 != 0)) != 0);
+ if (__pyx_t_8) {
+
+ /* "ot/emd/emd.pyx":61
+ *
+ * if not len(a):
+ * a=np.ones((n1,))/n1 # <<<<<<<<<<<<<<
+ *
+ * if not len(b):
+ */
+ __pyx_t_3 = __Pyx_GetModuleGlobalName(__pyx_n_s_np); if (unlikely(!__pyx_t_3)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 61; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __Pyx_GOTREF(__pyx_t_3);
+ __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_t_3, __pyx_n_s_ones); if (unlikely(!__pyx_t_2)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 61; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __Pyx_GOTREF(__pyx_t_2);
+ __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
+ __pyx_t_3 = __Pyx_PyInt_From_int(__pyx_v_n1); if (unlikely(!__pyx_t_3)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 61; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __Pyx_GOTREF(__pyx_t_3);
+ __pyx_t_5 = PyTuple_New(1); if (unlikely(!__pyx_t_5)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 61; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __Pyx_GOTREF(__pyx_t_5);
+ __Pyx_GIVEREF(__pyx_t_3);
+ PyTuple_SET_ITEM(__pyx_t_5, 0, __pyx_t_3);
+ __pyx_t_3 = 0;
+ __pyx_t_3 = NULL;
+ if (CYTHON_COMPILING_IN_CPYTHON && unlikely(PyMethod_Check(__pyx_t_2))) {
+ __pyx_t_3 = PyMethod_GET_SELF(__pyx_t_2);
+ if (likely(__pyx_t_3)) {
+ PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_2);
+ __Pyx_INCREF(__pyx_t_3);
+ __Pyx_INCREF(function);
+ __Pyx_DECREF_SET(__pyx_t_2, function);
+ }
+ }
+ if (!__pyx_t_3) {
+ __pyx_t_1 = __Pyx_PyObject_CallOneArg(__pyx_t_2, __pyx_t_5); if (unlikely(!__pyx_t_1)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 61; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0;
+ __Pyx_GOTREF(__pyx_t_1);
+ } else {
+ __pyx_t_4 = PyTuple_New(1+1); if (unlikely(!__pyx_t_4)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 61; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __Pyx_GOTREF(__pyx_t_4);
+ __Pyx_GIVEREF(__pyx_t_3); PyTuple_SET_ITEM(__pyx_t_4, 0, __pyx_t_3); __pyx_t_3 = NULL;
+ __Pyx_GIVEREF(__pyx_t_5);
+ PyTuple_SET_ITEM(__pyx_t_4, 0+1, __pyx_t_5);
+ __pyx_t_5 = 0;
+ __pyx_t_1 = __Pyx_PyObject_Call(__pyx_t_2, __pyx_t_4, NULL); if (unlikely(!__pyx_t_1)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 61; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __Pyx_GOTREF(__pyx_t_1);
+ __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
+ }
+ __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
+ __pyx_t_2 = __Pyx_PyInt_From_int(__pyx_v_n1); if (unlikely(!__pyx_t_2)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 61; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __Pyx_GOTREF(__pyx_t_2);
+ __pyx_t_4 = __Pyx_PyNumber_Divide(__pyx_t_1, __pyx_t_2); if (unlikely(!__pyx_t_4)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 61; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __Pyx_GOTREF(__pyx_t_4);
+ __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
+ __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
+ if (!(likely(((__pyx_t_4) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_4, __pyx_ptype_5numpy_ndarray))))) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 61; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __pyx_t_9 = ((PyArrayObject *)__pyx_t_4);
+ {
+ __Pyx_BufFmt_StackElem __pyx_stack[1];
+ __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_a.rcbuffer->pybuffer);
+ __pyx_t_10 = __Pyx_GetBufferAndValidate(&__pyx_pybuffernd_a.rcbuffer->pybuffer, (PyObject*)__pyx_t_9, &__Pyx_TypeInfo_double, PyBUF_FORMAT| PyBUF_C_CONTIGUOUS, 1, 0, __pyx_stack);
+ if (unlikely(__pyx_t_10 < 0)) {
+ PyErr_Fetch(&__pyx_t_11, &__pyx_t_12, &__pyx_t_13);
+ if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_a.rcbuffer->pybuffer, (PyObject*)__pyx_v_a, &__Pyx_TypeInfo_double, PyBUF_FORMAT| PyBUF_C_CONTIGUOUS, 1, 0, __pyx_stack) == -1)) {
+ Py_XDECREF(__pyx_t_11); Py_XDECREF(__pyx_t_12); Py_XDECREF(__pyx_t_13);
+ __Pyx_RaiseBufferFallbackError();
+ } else {
+ PyErr_Restore(__pyx_t_11, __pyx_t_12, __pyx_t_13);
+ }
+ }
+ __pyx_pybuffernd_a.diminfo[0].strides = __pyx_pybuffernd_a.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_a.diminfo[0].shape = __pyx_pybuffernd_a.rcbuffer->pybuffer.shape[0];
+ if (unlikely(__pyx_t_10 < 0)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 61; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ }
+ __pyx_t_9 = 0;
+ __Pyx_DECREF_SET(__pyx_v_a, ((PyArrayObject *)__pyx_t_4));
+ __pyx_t_4 = 0;
+
+ /* "ot/emd/emd.pyx":60
+ * cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2])
+ *
+ * if not len(a): # <<<<<<<<<<<<<<
+ * a=np.ones((n1,))/n1
+ *
+ */
+ }
+
+ /* "ot/emd/emd.pyx":63
+ * a=np.ones((n1,))/n1
+ *
+ * if not len(b): # <<<<<<<<<<<<<<
+ * b=np.ones((n2,))/n2
+ *
+ */
+ __pyx_t_7 = PyObject_Length(((PyObject *)__pyx_v_b)); if (unlikely(__pyx_t_7 == -1)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 63; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __pyx_t_8 = ((!(__pyx_t_7 != 0)) != 0);
+ if (__pyx_t_8) {
+
+ /* "ot/emd/emd.pyx":64
+ *
+ * if not len(b):
+ * b=np.ones((n2,))/n2 # <<<<<<<<<<<<<<
+ *
+ * # calling the function
+ */
+ __pyx_t_2 = __Pyx_GetModuleGlobalName(__pyx_n_s_np); if (unlikely(!__pyx_t_2)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 64; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __Pyx_GOTREF(__pyx_t_2);
+ __pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_t_2, __pyx_n_s_ones); if (unlikely(!__pyx_t_1)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 64; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __Pyx_GOTREF(__pyx_t_1);
+ __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
+ __pyx_t_2 = __Pyx_PyInt_From_int(__pyx_v_n2); if (unlikely(!__pyx_t_2)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 64; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __Pyx_GOTREF(__pyx_t_2);
+ __pyx_t_5 = PyTuple_New(1); if (unlikely(!__pyx_t_5)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 64; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __Pyx_GOTREF(__pyx_t_5);
+ __Pyx_GIVEREF(__pyx_t_2);
+ PyTuple_SET_ITEM(__pyx_t_5, 0, __pyx_t_2);
+ __pyx_t_2 = 0;
+ __pyx_t_2 = NULL;
+ if (CYTHON_COMPILING_IN_CPYTHON && unlikely(PyMethod_Check(__pyx_t_1))) {
+ __pyx_t_2 = PyMethod_GET_SELF(__pyx_t_1);
+ if (likely(__pyx_t_2)) {
+ PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_1);
+ __Pyx_INCREF(__pyx_t_2);
+ __Pyx_INCREF(function);
+ __Pyx_DECREF_SET(__pyx_t_1, function);
+ }
+ }
+ if (!__pyx_t_2) {
+ __pyx_t_4 = __Pyx_PyObject_CallOneArg(__pyx_t_1, __pyx_t_5); if (unlikely(!__pyx_t_4)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 64; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0;
+ __Pyx_GOTREF(__pyx_t_4);
+ } else {
+ __pyx_t_3 = PyTuple_New(1+1); if (unlikely(!__pyx_t_3)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 64; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __Pyx_GOTREF(__pyx_t_3);
+ __Pyx_GIVEREF(__pyx_t_2); PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_t_2); __pyx_t_2 = NULL;
+ __Pyx_GIVEREF(__pyx_t_5);
+ PyTuple_SET_ITEM(__pyx_t_3, 0+1, __pyx_t_5);
+ __pyx_t_5 = 0;
+ __pyx_t_4 = __Pyx_PyObject_Call(__pyx_t_1, __pyx_t_3, NULL); if (unlikely(!__pyx_t_4)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 64; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __Pyx_GOTREF(__pyx_t_4);
+ __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
+ }
+ __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
+ __pyx_t_1 = __Pyx_PyInt_From_int(__pyx_v_n2); if (unlikely(!__pyx_t_1)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 64; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __Pyx_GOTREF(__pyx_t_1);
+ __pyx_t_3 = __Pyx_PyNumber_Divide(__pyx_t_4, __pyx_t_1); if (unlikely(!__pyx_t_3)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 64; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __Pyx_GOTREF(__pyx_t_3);
+ __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
+ __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
+ if (!(likely(((__pyx_t_3) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_3, __pyx_ptype_5numpy_ndarray))))) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 64; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __pyx_t_14 = ((PyArrayObject *)__pyx_t_3);
+ {
+ __Pyx_BufFmt_StackElem __pyx_stack[1];
+ __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_b.rcbuffer->pybuffer);
+ __pyx_t_10 = __Pyx_GetBufferAndValidate(&__pyx_pybuffernd_b.rcbuffer->pybuffer, (PyObject*)__pyx_t_14, &__Pyx_TypeInfo_double, PyBUF_FORMAT| PyBUF_C_CONTIGUOUS, 1, 0, __pyx_stack);
+ if (unlikely(__pyx_t_10 < 0)) {
+ PyErr_Fetch(&__pyx_t_13, &__pyx_t_12, &__pyx_t_11);
+ if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_b.rcbuffer->pybuffer, (PyObject*)__pyx_v_b, &__Pyx_TypeInfo_double, PyBUF_FORMAT| PyBUF_C_CONTIGUOUS, 1, 0, __pyx_stack) == -1)) {
+ Py_XDECREF(__pyx_t_13); Py_XDECREF(__pyx_t_12); Py_XDECREF(__pyx_t_11);
+ __Pyx_RaiseBufferFallbackError();
+ } else {
+ PyErr_Restore(__pyx_t_13, __pyx_t_12, __pyx_t_11);
+ }
+ }
+ __pyx_pybuffernd_b.diminfo[0].strides = __pyx_pybuffernd_b.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_b.diminfo[0].shape = __pyx_pybuffernd_b.rcbuffer->pybuffer.shape[0];
+ if (unlikely(__pyx_t_10 < 0)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 64; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ }
+ __pyx_t_14 = 0;
+ __Pyx_DECREF_SET(__pyx_v_b, ((PyArrayObject *)__pyx_t_3));
+ __pyx_t_3 = 0;
+
+ /* "ot/emd/emd.pyx":63
+ * a=np.ones((n1,))/n1
+ *
+ * if not len(b): # <<<<<<<<<<<<<<
+ * b=np.ones((n2,))/n2
+ *
+ */
+ }
+
+ /* "ot/emd/emd.pyx":67
*
* # calling the function
* EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost) # <<<<<<<<<<<<<<
@@ -1458,7 +1659,7 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self,
*/
EMD_wrap(__pyx_v_n1, __pyx_v_n2, ((double *)__pyx_v_a->data), ((double *)__pyx_v_b->data), ((double *)__pyx_v_M->data), ((double *)__pyx_v_G->data), ((double *)(&__pyx_v_cost)));
- /* "ot/emd/emd.pyx":31
+ /* "ot/emd/emd.pyx":69
* EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost)
*
* return G # <<<<<<<<<<<<<<
@@ -1472,8 +1673,8 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self,
* @cython.boundscheck(False)
* @cython.wraparound(False)
* def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M): # <<<<<<<<<<<<<<
- * cdef int n1= M.shape[0]
- * cdef int n2= M.shape[1]
+ * """
+ * Solves the Earth Movers distance problem and returns the optimal transport matrix
*/
/* function exit code */
@@ -1500,6 +1701,8 @@ static PyObject *__pyx_pf_2ot_3emd_3emd_emd(CYTHON_UNUSED PyObject *__pyx_self,
__Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_b.rcbuffer->pybuffer);
__pyx_L2:;
__Pyx_XDECREF((PyObject *)__pyx_v_G);
+ __Pyx_XDECREF((PyObject *)__pyx_v_a);
+ __Pyx_XDECREF((PyObject *)__pyx_v_b);
__Pyx_XGIVEREF(__pyx_r);
__Pyx_RefNannyFinishContext();
return __pyx_r;
@@ -3691,6 +3894,7 @@ static __Pyx_StringTabEntry __pyx_string_tab[] = {
{&__pyx_kp_u_ndarray_is_not_Fortran_contiguou, __pyx_k_ndarray_is_not_Fortran_contiguou, sizeof(__pyx_k_ndarray_is_not_Fortran_contiguou), 0, 1, 0, 0},
{&__pyx_n_s_np, __pyx_k_np, sizeof(__pyx_k_np), 0, 0, 1, 1},
{&__pyx_n_s_numpy, __pyx_k_numpy, sizeof(__pyx_k_numpy), 0, 0, 1, 1},
+ {&__pyx_n_s_ones, __pyx_k_ones, sizeof(__pyx_k_ones), 0, 0, 1, 1},
{&__pyx_n_s_ot_emd_emd, __pyx_k_ot_emd_emd, sizeof(__pyx_k_ot_emd_emd), 0, 0, 1, 1},
{&__pyx_n_s_range, __pyx_k_range, sizeof(__pyx_k_range), 0, 0, 1, 1},
{&__pyx_n_s_test, __pyx_k_test, sizeof(__pyx_k_test), 0, 0, 1, 1},
@@ -3781,8 +3985,8 @@ static int __Pyx_InitCachedConstants(void) {
* @cython.boundscheck(False)
* @cython.wraparound(False)
* def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M): # <<<<<<<<<<<<<<
- * cdef int n1= M.shape[0]
- * cdef int n2= M.shape[1]
+ * """
+ * Solves the Earth Movers distance problem and returns the optimal transport matrix
*/
__pyx_tuple__7 = PyTuple_Pack(7, __pyx_n_s_a, __pyx_n_s_b, __pyx_n_s_M, __pyx_n_s_n1, __pyx_n_s_n2, __pyx_n_s_cost, __pyx_n_s_G); if (unlikely(!__pyx_tuple__7)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 21; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
__Pyx_GOTREF(__pyx_tuple__7);
@@ -3924,8 +4128,8 @@ PyMODINIT_FUNC PyInit_emd(void)
* @cython.boundscheck(False)
* @cython.wraparound(False)
* def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M): # <<<<<<<<<<<<<<
- * cdef int n1= M.shape[0]
- * cdef int n2= M.shape[1]
+ * """
+ * Solves the Earth Movers distance problem and returns the optimal transport matrix
*/
__pyx_t_1 = PyCFunction_NewEx(&__pyx_mdef_2ot_3emd_3emd_1emd, NULL, __pyx_n_s_ot_emd_emd); if (unlikely(!__pyx_t_1)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 21; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
__Pyx_GOTREF(__pyx_t_1);
@@ -4817,6 +5021,11 @@ static CYTHON_INLINE int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type) {
return 0;
}
+static void __Pyx_RaiseBufferFallbackError(void) {
+ PyErr_SetString(PyExc_ValueError,
+ "Buffer acquisition failed on assignment; and then reacquiring the old buffer failed too!");
+}
+
static CYTHON_INLINE void __Pyx_ErrRestore(PyObject *type, PyObject *value, PyObject *tb) {
#if CYTHON_COMPILING_IN_CPYTHON
PyObject *tmp_type, *tmp_value, *tmp_tb;
diff --git a/ot/emd/emd.pyx b/ot/emd/emd.pyx
index e5ac8e0..753b195 100644
--- a/ot/emd/emd.pyx
+++ b/ot/emd/emd.pyx
@@ -22,17 +22,35 @@ def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode=
"""
Solves the Earth Movers distance problem and returns the optimal transport matrix
+ gamm=emd(a,b,M)
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F
-
- :param a: m weights of the source distribution (must sum to one)
- :param b: n weights of the target distribution (must sum to one)
- :param M: m x n cost matrix
- :type a: np.ndarray
- :type b: np.ndarray
- :type M: np.ndarray
- :return: Optimal transport matrix
- :rtype: np.ndarray
-
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the metric cost matrix
+ - a and b are the sample weights
+
+ Parameters
+ ----------
+ a : (ns,) ndarray
+ samples in the source domain (uniform waigth if empty)
+ b : (nt,) ndarray
+ samples in the target domain (uniform waigth if empty)
+ M : (ns,nt) ndarray
+ loss matrix
+
+
+ Returns
+ -------
+ gamma: (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
"""
cdef int n1= M.shape[0]
@@ -40,6 +58,12 @@ def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode=
cdef float cost=0
cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2])
+
+ if not len(a):
+ a=np.ones((n1,))/n1
+
+ if not len(b):
+ b=np.ones((n2,))/n2
# calling the function
EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost)
diff --git a/ot/utils.py b/ot/utils.py
index 1a1c6b8..582c3ff 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -1,14 +1,23 @@
import numpy as np
-from scipy.spatial.distance import cdist, pdist
+from scipy.spatial.distance import cdist
def dist(x1,x2=None,metric='sqeuclidean'):
"""Compute distance between samples in x1 and x2"""
if x2 is None:
- return pdist(x1,metric=metric)
+ return cdist(x1,x1,metric=metric)
else:
return cdist(x1,x2,metric=metric)
+
+def dist0(n,method='linear'):
+ """Compute stardard cos matrices for OT problems"""
+ res=0
+ if method=='linear':
+ x=np.arange(n,dtype=np.float64).reshape((n,1))
+ res=dist(x,x)
+ return res
+
def dots(*args):
""" Stupid but nice dots function for multiple matrix multiply """