summary refs log tree commit diff
path: root/ot/lp
diff options
context:
space:
mode:
author Antoine Rolet <antoine.rolet@gmail.com> 2017-09-07 13:50:41 +0900
committer Antoine Rolet <antoine.rolet@gmail.com> 2017-09-07 13:50:41 +0900
commit 12d9b3ff72e9669ccc0162e82b7a33beb51d3e25 (patch)
tree 72a2908e9d0e67f7e8499d7ef9aca1246528a980 /ot/lp
parent f8c1c8740f9974dcf4aaf191851d62149dceb91c (diff)
Return dual variables in an optional dictionary
Also removed some code duplication
Diffstat (limited to 'ot/lp')
-rw-r--r-- ot/lp/__init__.py 24
-rw-r--r-- ot/lp/emd_wrap.pyx 69
2 files changed, 18 insertions, 75 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 6048f60..c15e6b9 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -9,12 +9,12 @@ Solvers for the original linear program OT problem
import numpy as np
# import compiled emd
-from .emd_wrap import emd_c, emd2_c
+from .emd_wrap import emd_c
from ..utils import parmap
import multiprocessing
-def emd(a, b, M, numItermax=100000, dual_variables=False):
+def emd(a, b, M, numItermax=100000, log=False):
"""Solves the Earth Movers distance problem and returns the OT matrix
@@ -42,11 +42,17 @@ def emd(a, b, M, numItermax=100000, dual_variables=False):
numItermax : int, optional (default=100000)
The maximum number of iterations before stopping the optimization
algorithm if it has not converged.
+ log: boolean, optional (default=False)
+ If True, returns a dictionary containing the cost and dual
+ variables. Otherwise returns only the optimal transportation matrix.
Returns
-------
gamma: (ns x nt) ndarray
Optimal transportation matrix for the given parameters
+ log: dict
+ If input log is true, a dictionary containing the cost and dual
+ variables
Examples
@@ -86,9 +92,13 @@ def emd(a, b, M, numItermax=100000, dual_variables=False):
if len(b) == 0:
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
- G, alpha, beta = emd_c(a, b, M, numItermax)
- if dual_variables:
- return G, alpha, beta
+ G, cost, u, v = emd_c(a, b, M, numItermax)
+ if log:
+ log = {}
+ log['cost'] = cost
+ log['u'] = u
+ log['v'] = v
+ return G, log
return G
def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
@@ -163,11 +173,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
if len(b.shape)==1:
- return emd2_c(a, b, M, numItermax)[0]
+ return emd_c(a, b, M, numItermax)[1]
nb = b.shape[1]
# res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)]
def f(b):
- return emd2_c(a,b,M, numItermax)[0]
+ return emd_c(a,b,M, numItermax)[1]
res= parmap(f, [b[:,i] for i in range(nb)],processes)
return np.array(res)
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 9bea154..5618dfc 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -86,71 +86,4 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod
elif resultSolver == MAX_ITER_REACHED:
warnings.warn("numItermax reached before optimality. Try to increase numItermax.")
- return G, alpha, beta
-
-@cython.boundscheck(False)
-@cython.wraparound(False)
-def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int numItermax):
- """
- Solves the Earth Movers distance problem and returns the optimal transport loss
-
- gamm=emd(a,b,M)
-
- .. math::
- \gamma = arg\min_\gamma <\gamma,M>_F
-
- s.t. \gamma 1 = a
-
- \gamma^T 1= b
-
- \gamma\geq 0
- where :
-
- - M is the metric cost matrix
- - a and b are the sample weights
-
- Parameters
- ----------
- a : (ns,) ndarray, float64
- source histogram
- b : (nt,) ndarray, float64
- target histogram
- M : (ns,nt) ndarray, float64
- loss matrix
- numItermax : int
- The maximum number of iterations before stopping the optimization
- algorithm if it has not converged.
-
-
- Returns
- -------
- gamma: (ns x nt) ndarray
- Optimal transportation matrix for the given parameters
-
- """
- cdef int n1= M.shape[0]
- cdef int n2= M.shape[1]
-
- cdef double cost=0
- cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2])
-
- cdef np.ndarray[double, ndim = 1, mode = "c"] alpha = np.zeros([n1])
- cdef np.ndarray[double, ndim = 1, mode = "c"] beta = np.zeros([n2])
-
- if not len(a):
- a=np.ones((n1,))/n1
-
- if not len(b):
- b=np.ones((n2,))/n2
- # calling the function
- cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, numItermax)
- if resultSolver != OPTIMAL:
- if resultSolver == INFEASIBLE:
- warnings.warn("Problem infeasible. Check that a and b are in the simplex")
- elif resultSolver == UNBOUNDED:
- warnings.warn("Problem unbounded")
- elif resultSolver == MAX_ITER_REACHED:
- warnings.warn("numItermax reached before optimality. Try to increase numItermax.")
-
- return cost, alpha, beta
-
+ return G, cost, alpha, beta