summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAntoine Rolet <antoine.rolet@gmail.com>2017-09-07 13:50:41 +0900
committerAntoine Rolet <antoine.rolet@gmail.com>2017-09-07 13:50:41 +0900
commit12d9b3ff72e9669ccc0162e82b7a33beb51d3e25 (patch)
tree72a2908e9d0e67f7e8499d7ef9aca1246528a980
parentf8c1c8740f9974dcf4aaf191851d62149dceb91c (diff)
Return dual variables in an optional dictionary
Also removed some code duplication
-rw-r--r--ot/lp/__init__.py24
-rw-r--r--ot/lp/emd_wrap.pyx69
-rw-r--r--test/test_ot.py20
3 files changed, 25 insertions, 88 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 6048f60..c15e6b9 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -9,12 +9,12 @@ Solvers for the original linear program OT problem
import numpy as np
# import compiled emd
-from .emd_wrap import emd_c, emd2_c
+from .emd_wrap import emd_c
from ..utils import parmap
import multiprocessing
-def emd(a, b, M, numItermax=100000, dual_variables=False):
+def emd(a, b, M, numItermax=100000, log=False):
"""Solves the Earth Movers distance problem and returns the OT matrix
@@ -42,11 +42,17 @@ def emd(a, b, M, numItermax=100000, dual_variables=False):
numItermax : int, optional (default=100000)
The maximum number of iterations before stopping the optimization
algorithm if it has not converged.
+ log: boolean, optional (default=False)
+ If True, returns a dictionary containing the cost and dual
+ variables. Otherwise returns only the optimal transportation matrix.
Returns
-------
gamma: (ns x nt) ndarray
Optimal transportation matrix for the given parameters
+ log: dict
+ If input log is true, a dictionary containing the cost and dual
+ variables
Examples
@@ -86,9 +92,13 @@ def emd(a, b, M, numItermax=100000, dual_variables=False):
if len(b) == 0:
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
- G, alpha, beta = emd_c(a, b, M, numItermax)
- if dual_variables:
- return G, alpha, beta
+ G, cost, u, v = emd_c(a, b, M, numItermax)
+ if log:
+ log = {}
+ log['cost'] = cost
+ log['u'] = u
+ log['v'] = v
+ return G, log
return G
def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
@@ -163,11 +173,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
if len(b.shape)==1:
- return emd2_c(a, b, M, numItermax)[0]
+ return emd_c(a, b, M, numItermax)[1]
nb = b.shape[1]
# res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)]
def f(b):
- return emd2_c(a,b,M, numItermax)[0]
+ return emd_c(a,b,M, numItermax)[1]
res= parmap(f, [b[:,i] for i in range(nb)],processes)
return np.array(res)
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 9bea154..5618dfc 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -86,71 +86,4 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod
elif resultSolver == MAX_ITER_REACHED:
warnings.warn("numItermax reached before optimality. Try to increase numItermax.")
- return G, alpha, beta
-
-@cython.boundscheck(False)
-@cython.wraparound(False)
-def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int numItermax):
- """
- Solves the Earth Movers distance problem and returns the optimal transport loss
-
- gamm=emd(a,b,M)
-
- .. math::
- \gamma = arg\min_\gamma <\gamma,M>_F
-
- s.t. \gamma 1 = a
-
- \gamma^T 1= b
-
- \gamma\geq 0
- where :
-
- - M is the metric cost matrix
- - a and b are the sample weights
-
- Parameters
- ----------
- a : (ns,) ndarray, float64
- source histogram
- b : (nt,) ndarray, float64
- target histogram
- M : (ns,nt) ndarray, float64
- loss matrix
- numItermax : int
- The maximum number of iterations before stopping the optimization
- algorithm if it has not converged.
-
-
- Returns
- -------
- gamma: (ns x nt) ndarray
- Optimal transportation matrix for the given parameters
-
- """
- cdef int n1= M.shape[0]
- cdef int n2= M.shape[1]
-
- cdef double cost=0
- cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2])
-
- cdef np.ndarray[double, ndim = 1, mode = "c"] alpha = np.zeros([n1])
- cdef np.ndarray[double, ndim = 1, mode = "c"] beta = np.zeros([n2])
-
- if not len(a):
- a=np.ones((n1,))/n1
-
- if not len(b):
- b=np.ones((n2,))/n2
- # calling the function
- cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, numItermax)
- if resultSolver != OPTIMAL:
- if resultSolver == INFEASIBLE:
- warnings.warn("Problem infeasible. Check that a and b are in the simplex")
- elif resultSolver == UNBOUNDED:
- warnings.warn("Problem unbounded")
- elif resultSolver == MAX_ITER_REACHED:
- warnings.warn("numItermax reached before optimality. Try to increase numItermax.")
-
- return cost, alpha, beta
-
+ return G, cost, alpha, beta
diff --git a/test/test_ot.py b/test/test_ot.py
index 8a19cf6..78f64ab 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -124,27 +124,26 @@ def test_warnings():
# %%
print('Computing {} EMD '.format(1))
- G, alpha, beta = ot.emd(a, b, M, dual_variables=True)
with warnings.catch_warnings(record=True) as w:
# Cause all warnings to always be triggered.
warnings.simplefilter("always")
# Trigger a warning.
print('Computing {} EMD '.format(1))
- G, alpha, beta = ot.emd(a, b, M, dual_variables=True, numItermax=1)
+ G = ot.emd(a, b, M, numItermax=1)
# Verify some things
assert "numItermax" in str(w[-1].message)
assert len(w) == 1
# Trigger a warning.
a[0]=100
print('Computing {} EMD '.format(2))
- G, alpha, beta = ot.emd(a, b, M, dual_variables=True)
+ G = ot.emd(a, b, M)
# Verify some things
assert "infeasible" in str(w[-1].message)
assert len(w) == 2
# Trigger a warning.
a[0]=-1
print('Computing {} EMD '.format(2))
- G, alpha, beta = ot.emd(a, b, M, dual_variables=True)
+ G = ot.emd(a, b, M)
# Verify some things
assert "infeasible" in str(w[-1].message)
assert len(w) == 3
@@ -176,16 +175,11 @@ def test_dual_variables():
# emd loss 1 proc
ot.tic()
- G, alpha, beta = ot.emd(a, b, M, dual_variables=True)
+ G, log = ot.emd(a, b, M, log=True)
ot.toc('1 proc : {} s')
cost1 = (G * M).sum()
- cost_dual = np.vdot(a, alpha) + np.vdot(b, beta)
-
- # emd loss 1 proc
- ot.tic()
- cost_emd2 = ot.emd2(a, b, M)
- ot.toc('1 proc : {} s')
+ cost_dual = np.vdot(a, log['u']) + np.vdot(b, log['v'])
ot.tic()
G2 = ot.emd(b, a, np.ascontiguousarray(M.T))
@@ -194,7 +188,7 @@ def test_dual_variables():
cost2 = (G2 * M.T).sum()
# Check that both cost computations are equivalent
- np.testing.assert_almost_equal(cost1, cost_emd2)
+ np.testing.assert_almost_equal(cost1, log['cost'])
# Check that dual and primal cost are equal
np.testing.assert_almost_equal(cost1, cost_dual)
# Check symmetry
@@ -205,5 +199,5 @@ def test_dual_variables():
[ind1, ind2] = np.nonzero(G)
# Check that reduced cost is zero on transport arcs
- np.testing.assert_array_almost_equal((M - alpha.reshape(-1, 1) - beta.reshape(1, -1))[ind1, ind2],
+ np.testing.assert_array_almost_equal((M - log['u'].reshape(-1, 1) - log['v'].reshape(1, -1))[ind1, ind2],
np.zeros(ind1.size))