summaryrefslogtreecommitdiff
path: root/ot/lp
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2020-01-27 10:40:05 +0100
committerRémi Flamary <remi.flamary@gmail.com>2020-01-27 10:40:05 +0100
commite5196fa7a8c493b831fd5dac52a89bbf29e7b0e6 (patch)
treebb24f3e79bf5d1f57f378299a3a147cd7db06c7a /ot/lp
parent30fc233f7f62d571a562971a945d68c3782f0780 (diff)
correct bug in emd emd2 still todo
Diffstat (limited to 'ot/lp')
-rw-r--r--ot/lp/__init__.py194
-rw-r--r--ot/lp/emd_wrap.pyx2
2 files changed, 174 insertions, 22 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index eabdd3a..a771ce4 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -23,11 +23,150 @@ from ..utils import parmap
from .cvx import barycenter
from ..utils import dist
-__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
- 'emd_1d', 'emd2_1d', 'wasserstein_1d']
+__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
+ 'emd_1d', 'emd2_1d', 'wasserstein_1d']
-def emd(a, b, M, numItermax=100000, log=False, dense=True):
+def center_ot_dual(alpha0, beta0, a=None, b=None):
+ r"""Center dual OT potentials wrt theirs weights
+
+ The main idea of this function is to find unique dual potentials
+ that ensure some kind of centering/fairness. It will help have
+ stability when multiple calling of the OT solver with small changes.
+
+ Basically we add another constraint to the potential that will not
+ change the objective value but will ensure unicity. The constraint
+ is the following:
+
+ .. math::
+ \alpha^T a= \beta^T b
+
+ in addition to the OT problem constraints.
+
+ since :math:`\sum_i a_i=\sum_j b_j` this can be solved by adding/removing
+ a constant from both :math:`\alpha_0` and :math:`\beta_0`.
+
+ .. math::
+ c=\frac{\beta0^T b-\alpha_0^T a}{1^Tb+1^Ta}
+
+ \alpha=\alpha_0+c
+
+ \beta=\beta0+c
+
+ Parameters
+ ----------
+ alpha0 : (ns,) numpy.ndarray, float64
+ Source dual potential
+ beta0 : (nt,) numpy.ndarray, float64
+ Target dual potential
+ a : (ns,) numpy.ndarray, float64
+ Source histogram (uniform weight if empty list)
+ b : (nt,) numpy.ndarray, float64
+ Target histogram (uniform weight if empty list)
+
+ Returns
+ -------
+ alpha : (ns,) numpy.ndarray, float64
+ Source centered dual potential
+ beta : (nt,) numpy.ndarray, float64
+ Target centered dual potential
+
+ """
+ # if no weights are provided, use uniform
+ if a is None:
+ a = np.ones(alpha0.shape[0]) / alpha0.shape[0]
+ if b is None:
+ b = np.ones(beta0.shape[0]) / beta0.shape[0]
+
+ # compute constant that balances the weighted sums of the duals
+ c = (b.dot(beta0) - a.dot(alpha0)) / (a.sum() + b.sum())
+
+ # update duals
+ alpha = alpha0 + c
+ beta = beta0 - c
+
+ return alpha, beta
+
+
+def estimate_dual_null_weights(alpha0, beta0, a, b, M):
+ r"""Estimate feasible values for 0-weighted dual potentials
+
+ The feasible values are computed efficiently bjt rather coarsely.
+ First we compute the constraints violations:
+
+ .. math::
+ V=\alpha+\beta^T-M
+
+ Next we compute the max amount of violation per row (alpha) and
+ columns (beta)
+
+ .. math::
+ v^a_i=\max_j V_{i,j}
+
+ v^b_j=\max_i V_{i,j}
+
+ Finally we update the dual potential with 0 weights if a
+ constraint is violated
+
+ .. math::
+ \alpha_i = \alpha_i -v^a_i \quad \text{ if } a_i=0 \text{ and } v^a_i>0
+
+ \beta_j = \beta_j -v^b_j \quad \text{ if } b_j=0 \text{ and } v^b_j>0
+
+ In the end the dual potential are centred using function
+ :ref:`center_ot_dual`.
+
+ Note that all those updates do not change the objective value of the
+ solution but provide dual potential that do not violate the constraints.
+
+ Parameters
+ ----------
+ alpha0 : (ns,) numpy.ndarray, float64
+ Source dual potential
+ beta0 : (nt,) numpy.ndarray, float64
+ Target dual potential
+ alpha0 : (ns,) numpy.ndarray, float64
+ Source dual potential
+ beta0 : (nt,) numpy.ndarray, float64
+ Target dual potential
+ a : (ns,) numpy.ndarray, float64
+ Source histogram (uniform weight if empty list)
+ b : (nt,) numpy.ndarray, float64
+ Target histogram (uniform weight if empty list)
+ M : (ns,nt) numpy.ndarray, float64
+ Loss matrix (c-order array with type float64)
+
+ Returns
+ -------
+ alpha : (ns,) numpy.ndarray, float64
+ Source corrected dual potential
+ beta : (nt,) numpy.ndarray, float64
+ Target corrected dual potential
+
+ """
+
+ # binary indexing of non-zeros weights
+ asel = a != 0
+ bsel = b != 0
+
+ # compute dual constraints violation
+ Viol = alpha0[:, None] + beta0[None, :] - M
+
+ # Compute worst violation per line and columns
+ aviol = np.max(Viol, 1)
+ bviol = np.max(Viol, 0)
+
+ # update corrects violation of
+ alpha_up = -1 * ~asel * np.maximum(aviol, 0)
+ beta_up = -1 * ~bsel * np.maximum(bviol, 0)
+
+ alpha = alpha0 + alpha_up
+ beta = beta0 + beta_up
+
+ return center_ot_dual(alpha, beta, a, b)
+
+
+def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True):
r"""Solves the Earth Movers distance problem and returns the OT matrix
@@ -43,7 +182,7 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
- a and b are the sample weights
.. warning::
- Note that the M matrix needs to be a C-order numpy.array in float64
+ Note that the M matrix needs to be a C-order numpy.array in float64
format.
Uses the algorithm proposed in [1]_
@@ -66,6 +205,9 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
Otherwise returns a sparse representation using scipy's `coo_matrix`
format.
+ center_dual: boolean, optional (default=True)
+ If True, centers the dual potential using function
+ :ref:`center_ot_dual`.
Returns
-------
@@ -107,7 +249,6 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
-
# if empty array given then use uniform distributions
if len(a) == 0:
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
@@ -117,11 +258,21 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
"Dimension mismatch, check dimensions of M with a and b"
+ asel = a != 0
+ bsel = b != 0
+
if dense:
- G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
+
+ if np.any(~asel) or np.any(~bsel):
+ u, v = estimate_dual_null_weights(u, v, a, b, M)
+
else:
- Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
- G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
+ Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
+ G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
+
+ if np.any(~asel) or np.any(~bsel):
+ u, v = estimate_dual_null_weights(u, v, a, b, M)
result_code_string = check_result(result_code)
if log:
@@ -151,7 +302,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
- a and b are the sample weights
.. warning::
- Note that the M matrix needs to be a C-order numpy.array in float64
+ Note that the M matrix needs to be a C-order numpy.array in float64
format.
Uses the algorithm proposed in [1]_
@@ -177,7 +328,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
dense: boolean, optional (default=True)
If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
Otherwise returns a sparse representation using scipy's `coo_matrix`
- format.
+ format.
Returns
-------
@@ -221,7 +372,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
# problem with pikling Forks
if sys.platform.endswith('win32'):
- processes=1
+ processes = 1
# if empty array given then use uniform distributions
if len(a) == 0:
@@ -235,10 +386,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
if log or return_matrix:
def f(b):
if dense:
- G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
else:
- Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
- G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
+ Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
+ G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
result_code_string = check_result(result_code)
log = {}
@@ -252,10 +403,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
else:
def f(b):
if dense:
- G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
else:
- Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
- G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
+ Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
+ G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
result_code_string = check_result(result_code)
check_result(result_code)
@@ -265,7 +416,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
return f(b)
nb = b.shape[1]
- if processes>1:
+ if processes > 1:
res = parmap(f, [b[:, i] for i in range(nb)], processes)
else:
res = list(map(f, [b[:, i].copy() for i in range(nb)]))
@@ -273,7 +424,6 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
return res
-
def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None):
"""
Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
@@ -326,7 +476,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
k = X_init.shape[0]
d = X_init.shape[1]
if b is None:
- b = np.ones((k,))/k
+ b = np.ones((k,)) / k
if weights is None:
weights = np.ones((N,)) / N
@@ -337,7 +487,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
displacement_square_norm = stopThr + 1.
- while ( displacement_square_norm > stopThr and iter_count < numItermax ):
+ while (displacement_square_norm > stopThr and iter_count < numItermax):
T_sum = np.zeros((k, d))
@@ -347,7 +497,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
T_i = emd(b, measure_weights_i, M_i)
T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
- displacement_square_norm = np.sum(np.square(T_sum-X))
+ displacement_square_norm = np.sum(np.square(T_sum - X))
if log:
displacement_square_norms.append(displacement_square_norm)
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index c0d7128..a4987f4 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -40,6 +40,8 @@ def check_result(result_code):
return message
+
+
@cython.boundscheck(False)
@cython.wraparound(False)
def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, bint dense):