// path: root/ot/lp/EMD_wrap.cpp (blob 52cd2629d23deb55081dbcbbf0d80a8b076ba82f)
/* This file contains a C++ wrapper function for computing the transportation
 * cost between two vectors given a cost matrix.
 *
 * It was written by Antoine Rolet (2014) and mainly consists of a wrapper
 * around the code written by Nicolas Bonneel, available on this page:
 *          http://people.seas.harvard.edu/~nbonneel/FastTransport/
 *
 * It was then modified to make it more amenable to inline calling from Python.
 *
 * Please give relevant credit to the original author (Nicolas Bonneel) if
 * you use this code for a publication.
 *
 */

#include "EMD.h"

#include <cstddef>
#include <iostream>
#include <limits>
#include <vector>


void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost)  {
// beware M and C anre strored in row major C style!!!
  int n, m, i,cur;
  double  max,max_iter;


    typedef FullBipartiteDigraph Digraph;
  DIGRAPH_TYPEDEFS(FullBipartiteDigraph);

  // Get the number of non zero coordinates for r and c
    n=0;
    for (node_id_type i=0; i<n1; i++) {
        double val=*(X+i);
        if (val>0) {
            n++;
        }
    }
    m=0;
    for (node_id_type i=0; i<n2; i++) {
        double val=*(Y+i);
        if (val>0) {
            m++;
        }
    }


    // Define the graph

    std::vector<int> indI(n), indJ(m);
    std::vector<double> weights1(n), weights2(m);
    Digraph di(n, m);
    NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m,max_iter);

    // Set supply and demand, don't account for 0 values (faster)

    max=0;
    cur=0;
    for (node_id_type i=0; i<n1; i++) {
        double val=*(X+i);
        if (val>0) {
            weights1[ di.nodeFromId(cur) ] = val;
            max+=val;
            indI[cur++]=i;
        }
    }

    // Demand is actually negative supply...

    max=0;
    cur=0;
    for (node_id_type i=0; i<n2; i++) {
        double val=*(Y+i);
        if (val>0) {
            weights2[ di.nodeFromId(cur) ] = -val;
            indJ[cur++]=i;

            max-=val;
        }
    }


    net.supplyMap(&weights1[0], n, &weights2[0], m);

    // Set the cost of each edge
    max=0;
    for (node_id_type i=0; i<n; i++) {
        for (node_id_type j=0; j<m; j++) {
            double val=*(D+indI[i]*n2+indJ[j]);
            net.setCost(di.arcFromId(i*m+j), val);
            if (val>max) {
                max=val;
            }
        }
    }


    // Solve the problem with the network simplex algorithm

    int ret=net.run();
    if (ret!=(int)net.OPTIMAL) {
        if (ret==(int)net.INFEASIBLE) {
            std::cout << "Infeasible problem";
        }
        if (ret==(int)net.UNBOUNDED)
        {
            std::cout << "Unbounded problem";
        }
    } else
    {
        for (node_id_type i=0; i<n; i++)
        {
            for (node_id_type j=0; j<m; j++)
            {
                *(G+indI[i]*n2+indJ[j]) = net.flow(di.arcFromId(i*m+j));
            }
        };
        *cost = net.totalCost();

    };



}