summaryrefslogtreecommitdiff
path: root/ot/emd/emd.pyx
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-21 09:26:17 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-21 09:26:17 +0200
commit2109443f5bea396114d1f9e0563ba5c396378c57 (patch)
tree8717a7d7b04e3d862a5013cbdf2a9757e490dffa /ot/emd/emd.pyx
parent9816f9e29fd858602c2bd6d64deb8e157a9c3be2 (diff)
add emd
Diffstat (limited to 'ot/emd/emd.pyx')
-rw-r--r--ot/emd/emd.pyx31
1 files changed, 31 insertions, 0 deletions
diff --git a/ot/emd/emd.pyx b/ot/emd/emd.pyx
new file mode 100644
index 0000000..d090cea
--- /dev/null
+++ b/ot/emd/emd.pyx
@@ -0,0 +1,31 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Thu Sep 11 08:42:08 2014
+
+@author: rflamary
+"""
+import numpy as np
+cimport numpy as np
+
+cimport cython
+
+
+
+cdef extern from "EMD.h":
+ void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost)
+
+
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M):
+ cdef int n1= M.shape[0]
+ cdef int n2= M.shape[1]
+
+ cdef float cost=0
+ cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2])
+
+ # calling the function
+ EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost)
+
+ return G