summaryrefslogtreecommitdiff
path: root/ot/lp
diff options
context:
space:
mode:
authorAntoine Rolet <antoine.rolet@gmail.com>2017-09-05 15:30:50 +0900
committerAntoine Rolet <antoine.rolet@gmail.com>2017-09-05 15:30:50 +0900
commit13dfb3ddbbd8926b4751b82dd41c5570253b1f07 (patch)
treeb28098e98640c64483a599103e2fdb5df46d2c79 /ot/lp
parent185eb3e2ef34b5ce6b8f90a28a5bcc78432b7fd3 (diff)
parent16697047eff9326a0ecb483317c13a854a3d3a71 (diff)
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'ot/lp')
-rw-r--r--ot/lp/EMD.h8
-rw-r--r--ot/lp/EMD_wrapper.cpp9
-rw-r--r--ot/lp/__init__.py48
-rw-r--r--ot/lp/emd_wrap.pyx101
4 files changed, 97 insertions, 69 deletions
diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h
index fb7feca..15e9115 100644
--- a/ot/lp/EMD.h
+++ b/ot/lp/EMD.h
@@ -23,8 +23,12 @@
using namespace lemon;
typedef unsigned int node_id_type;
+enum ProblemType {
+ INFEASIBLE,
+ OPTIMAL,
+ UNBOUNDED
+};
-void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G,
- double* alpha, double* beta, double *cost, int max_iter);
+int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int max_iter);
#endif
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp
index 0977e75..8ac43c7 100644
--- a/ot/lp/EMD_wrapper.cpp
+++ b/ot/lp/EMD_wrapper.cpp
@@ -15,10 +15,11 @@
#include "EMD.h"
-void EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
+int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
double* alpha, double* beta, double *cost, int max_iter) {
// beware M and C anre strored in row major C style!!!
- int n, m, i,cur;
+ int n, m, i, cur;
+ double max;
typedef FullBipartiteDigraph Digraph;
DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
@@ -44,7 +45,7 @@ void EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
std::vector<int> indI(n), indJ(m);
std::vector<double> weights1(n), weights2(m);
Digraph di(n, m);
- NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m,max_iter);
+ NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, max_iter);
// Set supply and demand, don't account for 0 values (faster)
@@ -108,5 +109,5 @@ void EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
};
-
+ return ret;
}
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 915a18c..a14d4e4 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -3,6 +3,10 @@
Solvers for the original linear program OT problem
"""
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
import numpy as np
# import compiled emd
from .emd_wrap import emd_c, emd2_c
@@ -10,8 +14,7 @@ from ..utils import parmap
import multiprocessing
-
-def emd(a, b, M, dual_variables=False, max_iter=-1):
+def emd(a, b, M, numItermax=100000, dual_variables=False):
"""Solves the Earth Movers distance problem and returns the OT matrix
@@ -36,6 +39,9 @@ def emd(a, b, M, dual_variables=False, max_iter=-1):
Target histogram (uniform weigth if empty list)
M : (ns,nt) ndarray, float64
loss matrix
+ numItermax : int, optional (default=100000)
+ The maximum number of iterations before stopping the optimization
+ algorithm if it has not converged.
Returns
-------
@@ -48,7 +54,7 @@ def emd(a, b, M, dual_variables=False, max_iter=-1):
Simple example with obvious solution. The function emd accepts lists and
perform automatic conversion to numpy arrays
-
+
>>> import ot
>>> a=[.5,.5]
>>> b=[.5,.5]
@@ -80,13 +86,13 @@ def emd(a, b, M, dual_variables=False, max_iter=-1):
if len(b) == 0:
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
- G, alpha, beta = emd_c(a, b, M, max_iter)
+ G, alpha, beta = emd_c(a, b, M, numItermax)
if dual_variables:
return G, alpha, beta
return G
-def emd2(a, b, M,processes=multiprocessing.cpu_count(), max_iter=-1):
- """Solves the Earth Movers distance problem and returns the loss
+def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
+ """Solves the Earth Movers distance problem and returns the loss
.. math::
\gamma = arg\min_\gamma <\gamma,M>_F
@@ -109,6 +115,9 @@ def emd2(a, b, M,processes=multiprocessing.cpu_count(), max_iter=-1):
Target histogram (uniform weigth if empty list)
M : (ns,nt) ndarray, float64
loss matrix
+ numItermax : int, optional (default=100000)
+ The maximum number of iterations before stopping the optimization
+ algorithm if it has not converged.
Returns
-------
@@ -121,15 +130,15 @@ def emd2(a, b, M,processes=multiprocessing.cpu_count(), max_iter=-1):
Simple example with obvious solution. The function emd accepts lists and
perform automatic conversion to numpy arrays
-
-
+
+
>>> import ot
>>> a=[.5,.5]
>>> b=[.5,.5]
>>> M=[[0.,1.],[1.,0.]]
>>> ot.emd2(a,b,M)
0.0
-
+
References
----------
@@ -152,16 +161,13 @@ def emd2(a, b, M,processes=multiprocessing.cpu_count(), max_iter=-1):
a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0]
if len(b) == 0:
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
-
+
if len(b.shape)==1:
- return emd2_c(a, b, M, max_iter)[0]
- else:
- nb=b.shape[1]
- #res=[emd2_c(a,b[:,i].copy(),M) for i in range(nb)]
- def f(b):
- return emd2_c(a,b,M, max_iter)[0]
- res= parmap(f, [b[:,i] for i in range(nb)],processes)
- return np.array(res)
-
-
- \ No newline at end of file
+ return emd2_c(a, b, M, numItermax)[0]
+ nb = b.shape[1]
+ # res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)]
+
+ def f(b):
+ return emd2_c(a,b,M, max_iter)[0]
+ res= parmap(f, [b[:,i] for i in range(nb)],processes)
+ return np.array(res)
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 4febb32..7056e0e 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -1,9 +1,12 @@
# -*- coding: utf-8 -*-
"""
-Created on Thu Sep 11 08:42:08 2014
-
-@author: rflamary
+Cython linker with C solver
"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
import numpy as np
cimport numpy as np
@@ -12,47 +15,50 @@ cimport cython
cdef extern from "EMD.h":
- void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost,
- double* alpha, double* beta, int max_iter)
+ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int max_iter)
+ cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED
@cython.boundscheck(False)
@cython.wraparound(False)
-def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int maxiter):
+def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int max_iter):
"""
Solves the Earth Movers distance problem and returns the optimal transport matrix
-
+
gamm=emd(a,b,M)
-
+
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F
-
+ \gamma = arg\min_\gamma <\gamma,M>_F
+
s.t. \gamma 1 = a
-
- \gamma^T 1= b
-
+
+ \gamma^T 1= b
+
\gamma\geq 0
where :
-
+
- M is the metric cost matrix
- a and b are the sample weights
-
+
Parameters
----------
a : (ns,) ndarray, float64
- source histogram
+ source histogram
b : (nt,) ndarray, float64
target histogram
M : (ns,nt) ndarray, float64
- loss matrix
-
-
+ loss matrix
+ max_iter : int
+ The maximum number of iterations before stopping the optimization
+ algorithm if it has not converged.
+
+
Returns
-------
gamma: (ns x nt) ndarray
Optimal transportation matrix for the given parameters
-
+
"""
cdef int n1= M.shape[0]
cdef int n2= M.shape[1]
@@ -61,7 +67,8 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod
cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2])
cdef np.ndarray[double, ndim=1, mode="c"] alpha=np.zeros(n1)
cdef np.ndarray[double, ndim=1, mode="c"] beta=np.zeros(n2)
-
+
+
if not len(a):
a=np.ones((n1,))/n1
@@ -69,47 +76,54 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod
b=np.ones((n2,))/n2
# calling the function
- EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,
- <double*> alpha.data, <double*> beta.data, <double*> &cost, maxiter)
+ cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
+ if resultSolver != OPTIMAL:
+ if resultSolver == INFEASIBLE:
+ print("Problem infeasible. Try to increase numItermax.")
+ elif resultSolver == UNBOUNDED:
+ print("Problem unbounded")
return G, alpha, beta
@cython.boundscheck(False)
@cython.wraparound(False)
-def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int maxiter):
+def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int max_iter):
"""
Solves the Earth Movers distance problem and returns the optimal transport loss
-
+
gamm=emd(a,b,M)
-
+
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F
-
+ \gamma = arg\min_\gamma <\gamma,M>_F
+
s.t. \gamma 1 = a
-
- \gamma^T 1= b
-
+
+ \gamma^T 1= b
+
\gamma\geq 0
where :
-
+
- M is the metric cost matrix
- a and b are the sample weights
-
+
Parameters
----------
a : (ns,) ndarray, float64
- source histogram
+ source histogram
b : (nt,) ndarray, float64
target histogram
M : (ns,nt) ndarray, float64
- loss matrix
-
-
+ loss matrix
+ max_iter : int
+ The maximum number of iterations before stopping the optimization
+ algorithm if it has not converged.
+
+
Returns
-------
gamma: (ns x nt) ndarray
Optimal transportation matrix for the given parameters
-
+
"""
cdef int n1= M.shape[0]
cdef int n2= M.shape[1]
@@ -119,16 +133,19 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo
cdef np.ndarray[double, ndim = 1, mode = "c"] alpha = np.zeros([n1])
cdef np.ndarray[double, ndim = 1, mode = "c"] beta = np.zeros([n2])
-
+
if not len(a):
a=np.ones((n1,))/n1
if not len(b):
b=np.ones((n2,))/n2
# calling the function
- EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,
- <double*> alpha.data, <double*> beta.data, <double*> &cost, maxiter)
-
+ cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
+ if resultSolver != OPTIMAL:
+ if resultSolver == INFEASIBLE:
+ print("Problem infeasible. Try to inscrease numItermax.")
+ elif resultSolver == UNBOUNDED:
+ print("Problem unbounded")
return cost, alpha, beta