summaryrefslogtreecommitdiff
path: root/ot/emd/emd.pyx
blob: 753b195a672176532542a811e9e0d6a2c68e3505 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 11 08:42:08 2014

@author: rflamary
"""
import numpy as np
cimport numpy as np

cimport cython



# Declaration of the external solver entry point provided by the header "EMD.h".
# As used by emd() in this file: n1/n2 are the row/column counts of the cost
# matrix, X and Y receive the source and target weight buffers, D the cost
# matrix buffer, G a pre-allocated (n1 x n2) buffer for the transport plan,
# and cost a pointer to a double.
# NOTE(review): cost is presumably filled with the optimal transport cost by
# the C side -- confirm against EMD.h.
cdef extern from "EMD.h":
    void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost)



@cython.boundscheck(False)
@cython.wraparound(False)
def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"]  b,np.ndarray[double, ndim=2, mode="c"]  M):
    r"""
        Solves the Earth Movers distance problem and returns the optimal transport matrix

        gamm=emd(a,b,M)

    .. math::
        \gamma = arg\min_\gamma <\gamma,M>_F

        s.t. \gamma 1 = a

             \gamma^T 1= b

             \gamma\geq 0
    where :

    - M is the metric cost matrix
    - a and b are the sample weights

    Parameters
    ----------
    a : (ns,) ndarray
        sample weights in the source domain (uniform weights if empty)
    b : (nt,) ndarray
        sample weights in the target domain (uniform weights if empty)
    M : (ns,nt) ndarray
        loss matrix


    Returns
    -------
    gamma: (ns x nt) ndarray
        Optimal transportation matrix for the given parameters

    Raises
    ------
    ValueError
        If a non-empty ``a`` or ``b`` does not have the length given by the
        corresponding dimension of ``M``.

    """
    cdef int n1 = M.shape[0]
    cdef int n2 = M.shape[1]

    # EMD_wrap takes a ``double *cost``.  This must be a genuine C double:
    # handing it the address of a 4-byte float would let the C side write
    # 8 bytes into a 4-byte slot and corrupt the stack.
    cdef double cost = 0
    cdef np.ndarray[double, ndim=2, mode="c"] G = np.zeros([n1, n2])

    # Empty weight vectors mean "uniform weights".
    if not len(a):
        a = np.ones((n1,)) / n1

    if not len(b):
        b = np.ones((n2,)) / n2

    # The C routine is only told n1 and n2 and is handed raw pointers, so a
    # mismatched weight vector would be read out of bounds.  Fail loudly instead.
    if a.shape[0] != n1 or b.shape[0] != n2:
        raise ValueError(
            "Dimension mismatch: a has length %d and b has length %d, "
            "but M has shape (%d, %d)" % (a.shape[0], b.shape[0], n1, n2))

    # calling the function
    EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, &cost)

    return G