summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
Diffstat (limited to 'ot')
-rw-r--r--ot/lp/EMD.h2
-rw-r--r--ot/lp/EMD_wrapper.cpp50
-rw-r--r--ot/lp/__init__.py29
-rw-r--r--ot/lp/emd_wrap.pyx26
-rw-r--r--ot/lp/network_simplex_simple.h55
5 files changed, 81 insertions, 81 deletions
diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h
index aa92441..15e9115 100644
--- a/ot/lp/EMD.h
+++ b/ot/lp/EMD.h
@@ -29,6 +29,6 @@ enum ProblemType {
UNBOUNDED
};
-int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter);
+int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int max_iter);
#endif
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp
index c8c2eb3..8ac43c7 100644
--- a/ot/lp/EMD_wrapper.cpp
+++ b/ot/lp/EMD_wrapper.cpp
@@ -15,9 +15,10 @@
#include "EMD.h"
-int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter) {
+int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
+ double* alpha, double* beta, double *cost, int max_iter) {
// beware M and C anre strored in row major C style!!!
- int n, m, i,cur;
+ int n, m, i, cur;
double max;
typedef FullBipartiteDigraph Digraph;
@@ -25,21 +26,20 @@ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *c
// Get the number of non zero coordinates for r and c
n=0;
- for (node_id_type i=0; i<n1; i++) {
+ for (int i=0; i<n1; i++) {
double val=*(X+i);
if (val>0) {
n++;
}
}
m=0;
- for (node_id_type i=0; i<n2; i++) {
+ for (int i=0; i<n2; i++) {
double val=*(Y+i);
if (val>0) {
m++;
}
}
-
// Define the graph
std::vector<int> indI(n), indJ(m);
@@ -49,28 +49,23 @@ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *c
// Set supply and demand, don't account for 0 values (faster)
- max=0;
cur=0;
- for (node_id_type i=0; i<n1; i++) {
+ for (int i=0; i<n1; i++) {
double val=*(X+i);
if (val>0) {
- weights1[ di.nodeFromId(cur) ] = val;
- max+=val;
+ weights1[ cur ] = val;
indI[cur++]=i;
}
}
// Demand is actually negative supply...
- max=0;
cur=0;
- for (node_id_type i=0; i<n2; i++) {
+ for (int i=0; i<n2; i++) {
double val=*(Y+i);
if (val>0) {
- weights2[ di.nodeFromId(cur) ] = -val;
+ weights2[ cur ] = -val;
indJ[cur++]=i;
-
- max-=val;
}
}
@@ -78,14 +73,10 @@ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *c
net.supplyMap(&weights1[0], n, &weights2[0], m);
// Set the cost of each edge
- max=0;
- for (node_id_type i=0; i<n; i++) {
- for (node_id_type j=0; j<m; j++) {
+ for (int i=0; i<n; i++) {
+ for (int j=0; j<m; j++) {
double val=*(D+indI[i]*n2+indJ[j]);
net.setCost(di.arcFromId(i*m+j), val);
- if (val>max) {
- max=val;
- }
}
}
@@ -103,14 +94,17 @@ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *c
}
} else
{
- for (node_id_type i=0; i<n; i++)
- {
- for (node_id_type j=0; j<m; j++)
- {
- *(G+indI[i]*n2+indJ[j]) = net.flow(di.arcFromId(i*m+j));
- }
- };
- *cost = net.totalCost();
+ *cost = 0;
+ Arc a; di.first(a);
+ for (; a != INVALID; di.next(a)) {
+ int i = di.source(a);
+ int j = di.target(a);
+ double flow = net.flow(a);
+ *cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
+ *(G+indI[i]*n2+indJ[j-n]) = flow;
+ *(alpha + indI[i]) = -net.potential(i);
+ *(beta + indJ[j-n]) = net.potential(j);
+ }
};
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index de91e74..a14d4e4 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -14,7 +14,7 @@ from ..utils import parmap
import multiprocessing
-def emd(a, b, M, numItermax=100000):
+def emd(a, b, M, numItermax=100000, dual_variables=False):
"""Solves the Earth Movers distance problem and returns the OT matrix
@@ -86,8 +86,10 @@ def emd(a, b, M, numItermax=100000):
if len(b) == 0:
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
- return emd_c(a, b, M, numItermax)
-
+ G, alpha, beta = emd_c(a, b, M, numItermax)
+ if dual_variables:
+ return G, alpha, beta
+ return G
def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
"""Solves the Earth Movers distance problem and returns the loss
@@ -159,14 +161,13 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0]
if len(b) == 0:
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
-
- if len(b.shape) == 1:
- return emd2_c(a, b, M, numItermax)
- else:
- nb = b.shape[1]
- # res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)]
-
- def f(b):
- return emd2_c(a, b, M, numItermax)
- res = parmap(f, [b[:, i] for i in range(nb)], processes)
- return np.array(res)
+
+ if len(b.shape)==1:
+ return emd2_c(a, b, M, numItermax)[0]
+ nb = b.shape[1]
+ # res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)]
+
+ def f(b):
+ return emd2_c(a,b,M, max_iter)[0]
+ res= parmap(f, [b[:,i] for i in range(nb)],processes)
+ return np.array(res)
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 26d3330..7056e0e 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -15,7 +15,7 @@ cimport cython
cdef extern from "EMD.h":
- int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter)
+ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int max_iter)
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED
@@ -63,8 +63,11 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod
cdef int n1= M.shape[0]
cdef int n2= M.shape[1]
- cdef float cost=0
+ cdef double cost=0
cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2])
+ cdef np.ndarray[double, ndim=1, mode="c"] alpha=np.zeros(n1)
+ cdef np.ndarray[double, ndim=1, mode="c"] beta=np.zeros(n2)
+
if not len(a):
a=np.ones((n1,))/n1
@@ -73,14 +76,14 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod
b=np.ones((n2,))/n2
# calling the function
- cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost, max_iter)
+ cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
if resultSolver != OPTIMAL:
if resultSolver == INFEASIBLE:
print("Problem infeasible. Try to increase numItermax.")
elif resultSolver == UNBOUNDED:
print("Problem unbounded")
- return G
+ return G, alpha, beta
@cython.boundscheck(False)
@cython.wraparound(False)
@@ -125,27 +128,24 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo
cdef int n1= M.shape[0]
cdef int n2= M.shape[1]
- cdef float cost=0
+ cdef double cost=0
cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2])
+ cdef np.ndarray[double, ndim = 1, mode = "c"] alpha = np.zeros([n1])
+ cdef np.ndarray[double, ndim = 1, mode = "c"] beta = np.zeros([n2])
+
if not len(a):
a=np.ones((n1,))/n1
if not len(b):
b=np.ones((n2,))/n2
-
# calling the function
- cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost, max_iter)
+ cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
if resultSolver != OPTIMAL:
if resultSolver == INFEASIBLE:
print("Problem infeasible. Try to inscrease numItermax.")
elif resultSolver == UNBOUNDED:
print("Problem unbounded")
- cost=0
- for i in range(n1):
- for j in range(n2):
- cost+=G[i,j]*M[i,j]
-
- return cost
+ return cost, alpha, beta
diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h
index 64856a0..08449f6 100644
--- a/ot/lp/network_simplex_simple.h
+++ b/ot/lp/network_simplex_simple.h
@@ -28,6 +28,12 @@
#ifndef LEMON_NETWORK_SIMPLEX_SIMPLE_H
#define LEMON_NETWORK_SIMPLEX_SIMPLE_H
#define DEBUG_LVL 0
+
+#if DEBUG_LVL>0
+#include <iomanip>
+#endif
+
+
#define EPSILON 10*2.2204460492503131e-016
#define MAX_DEBUG_ITER 100000
@@ -220,7 +226,7 @@ namespace lemon {
/// mixed order in the internal data structure.
/// In special cases, it could lead to better overall performance,
/// but it is usually slower. Therefore it is disabled by default.
- NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, long long nb_arcs,double maxiters) :
+ NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, long long nb_arcs,int maxiters) :
_graph(graph), //_arc_id(graph),
_arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs),
MAX(std::numeric_limits<Value>::max()),
@@ -278,7 +284,7 @@ namespace lemon {
private:
- double max_iter;
+ int max_iter;
TEMPLATE_DIGRAPH_TYPEDEFS(GR);
typedef std::vector<int> IntVector;
@@ -676,14 +682,12 @@ namespace lemon {
/// \see resetParams(), reset()
ProblemType run() {
#if DEBUG_LVL>0
- mexPrintf("OPTIMAL = %d\nINFEASIBLE = %d\nUNBOUNDED = %d\n",OPTIMAL,INFEASIBLE,UNBOUNDED);
- mexEvalString("drawnow;");
+ std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "nUNBOUNDED = " << UNBOUNDED << "\n";
#endif
if (!init()) return INFEASIBLE;
#if DEBUG_LVL>0
- mexPrintf("Init done, starting iterations\n");
- mexEvalString("drawnow;");
+ std::cout << "Init done, starting iterations\n";
#endif
return start();
}
@@ -1422,10 +1426,10 @@ namespace lemon {
//pivot.setDantzig(true);
// Execute the Network Simplex algorithm
while (pivot.findEnteringArc()) {
- if(++iter_number>=max_iter&&max_iter>0){
+ if(max_iter > 0 && ++iter_number>=max_iter&&max_iter>0){
char errMess[1000];
- // sprintf( errMess, "RESULT MIGHT BE INACURATE\nMax number of iteration reached, currently \%d. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher",iter_number );
- // mexWarnMsgTxt(errMess);
+ sprintf( errMess, "RESULT MIGHT BE INACURATE\nMax number of iteration reached, currently \%d. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher\n",iter_number );
+ std::cerr << errMess;
break;
}
#if DEBUG_LVL>0
@@ -1440,12 +1444,13 @@ namespace lemon {
for (int i=0; i<_flow.size(); i++) {
sumFlow+=_state[i]*_flow[i];
}
- mexPrintf("Sum of the flow %.100f\n%d iterations, current cost=%.20f\nReduced cost=%.30f\nPrecision =%.30f\n",sumFlow,niter, curCost,_state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]), -EPSILON*(a));
- mexPrintf("Arc in = (%d,%d)\n",_node_id(_source[in_arc]),_node_id(_target[in_arc]));
- mexPrintf("Supplies = (%f,%f)\n",_supply[_source[in_arc]],_supply[_target[in_arc]]);
-
- mexPrintf("%.30f\n%.30f\n%.30f\n%.30f\n%",_cost[in_arc],_pi[_source[in_arc]],_pi[_target[in_arc]],a);
- mexEvalString("drawnow;");
+ std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << niter << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n";
+ std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n";
+ std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n";
+ std::cout << _cost[in_arc] << "\n";
+ std::cout << _pi[_source[in_arc]] << "\n";
+ std::cout << _pi[_target[in_arc]] << "\n";
+ std::cout << a << "\n";
}
#endif
@@ -1459,11 +1464,11 @@ namespace lemon {
}
#if DEBUG_LVL>0
else{
- mexPrintf("No change\n");
+ std::cout << "No change\n";
}
#endif
#if DEBUG_LVL>1
- mexPrintf("Arc in = (%d,%d)\n",_source[in_arc],_target[in_arc]);
+ std::cout << "Arc in = (" << _source[in_arc] << ", " << _target[in_arc] << ")\n";
#endif
}
@@ -1478,23 +1483,23 @@ namespace lemon {
for (int i=0; i<_flow.size(); i++) {
sumFlow+=_state[i]*_flow[i];
}
- mexPrintf("Sum of the flow %.100f\n%d iterations, current cost=%.20f\nReduced cost=%.30f\nPrecision =%.30f",sumFlow,niter, curCost,_state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]), -EPSILON*(a));
- mexPrintf("Arc in = (%d,%d)\n",_node_id(_source[in_arc]),_node_id(_target[in_arc]));
- mexPrintf("Supplies = (%f,%f)\n",_supply[_source[in_arc]],_supply[_target[in_arc]]);
+
+ std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << niter << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n";
+
+ std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n";
+ std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n";
- mexEvalString("drawnow;");
#endif
#if DEBUG_LVL>1
- double sumFlow=0;
+ sumFlow=0;
for (int i=0; i<_flow.size(); i++) {
sumFlow+=_state[i]*_flow[i];
if (_state[i]==STATE_TREE) {
- mexPrintf("Non zero value at (%d,%d)\n",_node_num+1-_source[i],_node_num+1-_target[i]);
+ std::cout << "Non zero value at (" << _node_num+1-_source[i] << ", " << _node_num+1-_target[i] << ")\n";
}
}
- mexPrintf("Sum of the flow %.100f\n%d iterations, current cost=%.20f\n",sumFlow,niter, totalCost());
- mexEvalString("drawnow;");
+ std::cout << "Sum of the flow " << sumFlow << "\n"<< niter <<" iterations, current cost=" << totalCost() << "\n";
#endif
// Check feasibility
for (int e = _search_arc_num; e != _all_arc_num; ++e) {