summaryrefslogtreecommitdiff
path: root/ot/lp/emd_wrap.pyx
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-03-21 09:32:06 +0100
committerRémi Flamary <remi.flamary@gmail.com>2017-03-21 09:32:06 +0100
commit0bbd36c0b569e58c3d47177e816b3541fac4b916 (patch)
tree2cba51d5447afa3356470371adb216b587832678 /ot/lp/emd_wrap.pyx
parenta84f2c3e23edd1fa89975bd77b08672f518d5ca4 (diff)
cleanup cpp wrapper name
Diffstat (limited to 'ot/lp/emd_wrap.pyx')
-rw-r--r--ot/lp/emd_wrap.pyx131
1 files changed, 131 insertions, 0 deletions
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
new file mode 100644
index 0000000..46794ab
--- /dev/null
+++ b/ot/lp/emd_wrap.pyx
@@ -0,0 +1,131 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Thu Sep 11 08:42:08 2014
+
+@author: rflamary
+"""
+import numpy as np
+cimport numpy as np
+
+cimport cython
+
+
+
+cdef extern from "EMD.h":  # external solver entry point declared in EMD.h
+ void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost)  # called in this file with X=a, Y=b, D=M, G=output plan buffer, cost=scalar passed by address (must point to a C double)
+
+
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M):
+    r"""
+    Solves the Earth Movers distance problem and returns the optimal transport matrix
+
+    gamma = emd_c(a, b, M)
+
+    .. math::
+        \gamma = arg\min_\gamma <\gamma,M>_F
+
+        s.t. \gamma 1 = a
+
+             \gamma^T 1= b
+
+             \gamma\geq 0
+    where :
+
+    - M is the metric cost matrix
+    - a and b are the sample weights
+
+    Parameters
+    ----------
+    a : (ns,) ndarray, float64
+        source histogram (an empty array is replaced by uniform weights 1/ns)
+    b : (nt,) ndarray, float64
+        target histogram (an empty array is replaced by uniform weights 1/nt)
+    M : (ns,nt) ndarray, float64
+        loss matrix
+
+
+    Returns
+    -------
+    gamma: (ns x nt) ndarray
+        Optimal transportation matrix for the given parameters
+
+    """
+    cdef int n1= M.shape[0]
+    cdef int n2= M.shape[1]
+
+    # EMD_wrap takes a double*: a 4-byte C float here would be overrun by an 8-byte write
+    cdef double cost=0
+    cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2])
+
+    if not len(a):
+        a=np.ones((n1,))/n1
+    if not len(b):
+        b=np.ones((n2,))/n2
+
+    # calling the function (G is filled in place; cost is not used by this wrapper)
+    EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,&cost)
+
+    return G
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M):
+    r"""
+    Solves the Earth Movers distance problem and returns the optimal transport loss
+
+    cost = emd2_c(a, b, M)
+
+    .. math::
+        cost = \min_\gamma <\gamma,M>_F
+
+        s.t. \gamma 1 = a
+
+             \gamma^T 1= b
+
+             \gamma\geq 0
+    where :
+
+    - M is the metric cost matrix
+    - a and b are the sample weights
+
+    Parameters
+    ----------
+    a : (ns,) ndarray, float64
+        source histogram (an empty array is replaced by uniform weights 1/ns)
+    b : (nt,) ndarray, float64
+        target histogram (an empty array is replaced by uniform weights 1/nt)
+    M : (ns,nt) ndarray, float64
+        loss matrix
+
+    Returns
+    -------
+    cost : float
+        Transport loss <gamma, M>_F of the transport matrix computed by the solver
+
+    """
+    cdef int n1= M.shape[0]
+    cdef int n2= M.shape[1]
+    cdef int i, j
+    # EMD_wrap takes a double*: a 4-byte C float here would be overrun by an 8-byte write
+    cdef double cost=0
+    cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2])
+
+    if not len(a):
+        a=np.ones((n1,))/n1
+    if not len(b):
+        b=np.ones((n2,))/n2
+
+    # calling the function (G is filled in place)
+    EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,&cost)
+
+    # discard whatever the solver stored and recompute the loss from the plan, in double precision
+    cost=0
+    for i in range(n1):
+        for j in range(n2):
+            cost+=G[i,j]*M[i,j]
+
+    return cost
+