summaryrefslogtreecommitdiff
path: root/ot/emd/emd.pyx
blob: d090cea2891fd2d6a504e73868f5886f5bfe6a5e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 11 08:42:08 2014

@author: rflamary
"""
import numpy as np
cimport numpy as np

cimport cython



# Prototype of the external C routine declared in "EMD.h"; its implementation
# is not part of this file.
cdef extern from "EMD.h":
    # n1, n2 : integer sizes passed by value.
    # X, Y, D, G, cost : raw pointers to double. The wrapper below passes the
    #   data of a, b, M and the freshly allocated G for X, Y, D and G, and the
    #   address of a scalar for cost.
    # NOTE(review): presumably X has n1 entries, Y has n2 entries, D and G are
    #   row-major n1*n2 matrices, and G / cost are outputs -- confirm against
    #   EMD.h and its implementation.
    void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost)



@cython.boundscheck(False)
@cython.wraparound(False)
def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"]  b,np.ndarray[double, ndim=2, mode="c"]  M):
    """
    Call the C routine ``EMD_wrap`` on ``a``, ``b`` and ``M`` and return the
    matrix it fills in.

    Parameters
    ----------
    a : C-contiguous 1-D ndarray of float64, length ``M.shape[0]``
        Passed to ``EMD_wrap`` as ``X``.
    b : C-contiguous 1-D ndarray of float64, length ``M.shape[1]``
        Passed to ``EMD_wrap`` as ``Y``.
    M : C-contiguous 2-D ndarray of float64
        Passed to ``EMD_wrap`` as ``D``; its shape defines ``n1`` and ``n2``.

    Returns
    -------
    G : ndarray of float64, shape ``(M.shape[0], M.shape[1])``
        Allocated zero-filled here and handed to ``EMD_wrap`` as ``G``.
        NOTE(review): presumably the optimal transport plan of the earth
        mover's distance problem -- confirm against the EMD.h implementation.

    Raises
    ------
    ValueError
        If the lengths of ``a`` and ``b`` do not match the shape of ``M``.

    Notes
    -----
    ``EMD_wrap`` also receives a pointer to a scalar ``cost``; that value is
    not returned, so the return type stays the same for existing callers.
    """
    cdef int n1 = M.shape[0]
    cdef int n2 = M.shape[1]

    # Must be a C double: EMD_wrap takes a ``double *``. Declaring this as a
    # 4-byte ``float`` and casting its address to ``double *`` would let the
    # C code write 8 bytes into a 4-byte stack slot.
    cdef double cost = 0
    cdef np.ndarray[double, ndim=2, mode="c"] G = np.zeros([n1, n2])

    # Bounds checking is disabled and raw pointers are passed to C together
    # with n1 and n2, so reject inconsistent sizes before the call rather than
    # risk out-of-bounds reads in the C code.
    if a.shape[0] != n1 or b.shape[0] != n2:
        raise ValueError(
            "Dimension mismatch: a has length {0}, b has length {1}, "
            "but M has shape ({2}, {3})".format(a.shape[0], b.shape[0], n1, n2))

    # calling the function
    EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, &cost)

    return G