From e5196fa7a8c493b831fd5dac52a89bbf29e7b0e6 Mon Sep 17 00:00:00 2001
From: Rémi Flamary
Date: Mon, 27 Jan 2020 10:40:05 +0100
Subject: correct bug in emd, emd2 still todo

---
 ot/lp/__init__.py  | 190 ++++++++++++++++++++++++++++++++++++++++++++++------
 ot/lp/emd_wrap.pyx |   2 +
 2 files changed, 170 insertions(+), 22 deletions(-)

diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index eabdd3a..a771ce4 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -23,11 +23,146 @@ from ..utils import parmap
 from .cvx import barycenter
 from ..utils import dist
 
-__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
-         'emd_1d', 'emd2_1d', 'wasserstein_1d']
+__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
+           'emd_1d', 'emd2_1d', 'wasserstein_1d']
 
 
-def emd(a, b, M, numItermax=100000, log=False, dense=True):
+def center_ot_dual(alpha0, beta0, a=None, b=None):
+    r"""Center dual OT potentials wrt their weights
+
+    The main idea of this function is to find unique dual potentials
+    that ensure some kind of centering/fairness. It helps keep the
+    solver stable when it is called repeatedly with small changes.
+
+    Basically we add another constraint to the potentials that does not
+    change the objective value but ensures uniqueness. The constraint
+    is the following:
+
+    .. math::
+        \alpha^T a = \beta^T b
+
+    in addition to the OT problem constraints.
+
+    Since :math:`\sum_i a_i = \sum_j b_j`, this can be done by adding/removing
+    a constant from both :math:`\alpha_0` and :math:`\beta_0`:
+
+    .. math::
+        c = \frac{\beta_0^T b - \alpha_0^T a}{1^T a + 1^T b}
+
+        \alpha = \alpha_0 + c
+
+        \beta = \beta_0 - c
+
+    Parameters
+    ----------
+    alpha0 : (ns,) numpy.ndarray, float64
+        Source dual potential
+    beta0 : (nt,) numpy.ndarray, float64
+        Target dual potential
+    a : (ns,) numpy.ndarray, float64
+        Source histogram (uniform weights if None)
+    b : (nt,) numpy.ndarray, float64
+        Target histogram (uniform weights if None)
+
+    Returns
+    -------
+    alpha : (ns,) numpy.ndarray, float64
+        Source centered dual potential
+    beta : (nt,) numpy.ndarray, float64
+        Target centered dual potential
+
+    """
+    # if no weights are provided, use uniform
+    if a is None:
+        a = np.ones(alpha0.shape[0]) / alpha0.shape[0]
+    if b is None:
+        b = np.ones(beta0.shape[0]) / beta0.shape[0]
+
+    # compute constant that balances the weighted sums of the duals
+    c = (b.dot(beta0) - a.dot(alpha0)) / (a.sum() + b.sum())
+
+    # update duals
+    alpha = alpha0 + c
+    beta = beta0 - c
+
+    return alpha, beta
+
+
+def estimate_dual_null_weights(alpha0, beta0, a, b, M):
+    r"""Estimate feasible values for 0-weighted dual potentials
+
+    The feasible values are computed efficiently but rather coarsely.
+    First we compute the constraint violations:
+
+    .. math::
+        V_{i,j} = \alpha_i + \beta_j - M_{i,j}
+
+    Next we compute the maximal violation per row (alpha) and per
+    column (beta):
+
+    .. math::
+        v^a_i = \max_j V_{i,j}
+
+        v^b_j = \max_i V_{i,j}
+
+    Finally we update the dual potentials with 0 weights when a
+    constraint is violated:
+
+    .. math::
+        \alpha_i = \alpha_i - v^a_i \quad \text{ if } a_i=0 \text{ and } v^a_i>0
+
+        \beta_j = \beta_j - v^b_j \quad \text{ if } b_j=0 \text{ and } v^b_j>0
+
+    In the end the dual potentials are centered using the function
+    :ref:`center_ot_dual`.
+
+    Note that all those updates do not change the objective value of the
+    solution but provide dual potentials that do not violate the constraints.
+
+    Parameters
+    ----------
+    alpha0 : (ns,) numpy.ndarray, float64
+        Source dual potential
+    beta0 : (nt,) numpy.ndarray, float64
+        Target dual potential
+    a : (ns,) numpy.ndarray, float64
+        Source histogram
+    b : (nt,) numpy.ndarray, float64
+        Target histogram
+    M : (ns,nt) numpy.ndarray, float64
+        Loss matrix (c-order array with type float64)
+
+    Returns
+    -------
+    alpha : (ns,) numpy.ndarray, float64
+        Source corrected dual potential
+    beta : (nt,) numpy.ndarray, float64
+        Target corrected dual potential
+
+    """
+
+    # binary indexing of non-zero weights
+    asel = a != 0
+    bsel = b != 0
+
+    # compute dual constraint violations
+    Viol = alpha0[:, None] + beta0[None, :] - M
+
+    # compute worst violation per row and per column
+    aviol = np.max(Viol, 1)
+    bviol = np.max(Viol, 0)
+
+    # updates that correct the violations on 0-weight entries
+    alpha_up = -1 * ~asel * np.maximum(aviol, 0)
+    beta_up = -1 * ~bsel * np.maximum(bviol, 0)
+
+    alpha = alpha0 + alpha_up
+    beta = beta0 + beta_up
+
+    return center_ot_dual(alpha, beta, a, b)
+
+
+def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True):
     r"""Solves the Earth Movers distance problem and returns
     the OT matrix
 
 
@@ -43,7 +178,7 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
     - a and b are the sample weights
 
     .. warning::
-        Note that the M matrix needs to be a C-order numpy.array in float64
+        Note that the M matrix needs to be a C-order numpy.array in float64
         format.
 
     Uses the algorithm proposed in [1]_
@@ -66,6 +201,9 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
         If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
         Otherwise returns a sparse representation using scipy's `coo_matrix`
        format.
+    center_dual: boolean, optional (default=True)
+        If True, centers the dual potentials using the function
+        :ref:`center_ot_dual`.
 
     Returns
     -------
@@ -107,7 +245,6 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
     b = np.asarray(b, dtype=np.float64)
     M = np.asarray(M, dtype=np.float64)
 
-
     # if empty array given then use uniform distributions
     if len(a) == 0:
         a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
@@ -117,11 +254,21 @@
     assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
         "Dimension mismatch, check dimensions of M with a and b"
 
+    asel = a != 0
+    bsel = b != 0
+
     if dense:
-        G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
+        G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
+
+        if np.any(~asel) or np.any(~bsel):
+            u, v = estimate_dual_null_weights(u, v, a, b, M)
+
     else:
-        Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
-        G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
+        Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
+        G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
+
+        if np.any(~asel) or np.any(~bsel):
+            u, v = estimate_dual_null_weights(u, v, a, b, M)
 
     result_code_string = check_result(result_code)
     if log:
@@ -151,7 +298,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
     - a and b are the sample weights
 
     .. warning::
-        Note that the M matrix needs to be a C-order numpy.array in float64
+        Note that the M matrix needs to be a C-order numpy.array in float64
         format.
 
     Uses the algorithm proposed in [1]_
 
@@ -177,7 +324,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
     dense: boolean, optional (default=True)
         If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
         Otherwise returns a sparse representation using scipy's `coo_matrix`
-        format.
+        format.
 
     Returns
     -------
@@ -221,7 +368,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
 
     # problem with pikling Forks
     if sys.platform.endswith('win32'):
-        processes=1
+        processes = 1
 
     # if empty array given then use uniform distributions
     if len(a) == 0:
@@ -235,10 +382,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
     if log or return_matrix:
         def f(b):
             if dense:
-                G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
+                G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
             else:
-                Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
-                G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
+                Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
+                G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
 
             result_code_string = check_result(result_code)
             log = {}
@@ -252,10 +399,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
     else:
         def f(b):
             if dense:
-                G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
+                G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
             else:
-                Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
-                G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
+                Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
+                G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
 
             result_code_string = check_result(result_code)
             check_result(result_code)
@@ -265,7 +412,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
         return f(b)
 
     nb = b.shape[1]
-    if processes>1:
+    if processes > 1:
         res = parmap(f, [b[:, i] for i in range(nb)], processes)
     else:
         res = list(map(f, [b[:, i].copy() for i in range(nb)]))
@@ -273,7 +420,6 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
 
     return res
 
-
 def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None):
     """
     Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
@@ -326,7 +472,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
     k = X_init.shape[0]
     d = X_init.shape[1]
     if b is None:
-        b = np.ones((k,))/k
+        b = np.ones((k,)) / k
 
     if weights is None:
         weights = np.ones((N,)) / N
@@ -337,7 +483,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
 
     displacement_square_norm = stopThr + 1.
 
-    while ( displacement_square_norm > stopThr and iter_count < numItermax ):
+    while (displacement_square_norm > stopThr and iter_count < numItermax):
 
         T_sum = np.zeros((k, d))
 
@@ -347,7 +493,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
 
             T_i = emd(b, measure_weights_i, M_i)
            T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
-        displacement_square_norm = np.sum(np.square(T_sum-X))
+        displacement_square_norm = np.sum(np.square(T_sum - X))
 
         if log:
             displacement_square_norms.append(displacement_square_norm)
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index c0d7128..a4987f4 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -40,6 +40,8 @@ def check_result(result_code):
 
     return message
 
+
+
 @cython.boundscheck(False)
 @cython.wraparound(False)
 def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, bint dense):
-- 
cgit v1.2.3
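
Editor's note: the dual post-processing added by this patch is easy to sanity-check
outside the solver. Below is a minimal standalone sketch of the same logic, assuming
only numpy; the helper names (center_duals, fix_null_weight_duals) and the toy data
are illustrative, not part of the POT API. It verifies the two properties the
docstrings claim: the centering constraint alpha^T a = beta^T b holds, and the dual
objective value a^T alpha + b^T beta is unchanged.

import numpy as np

def center_duals(alpha0, beta0, a, b):
    # shift the duals by a constant c so that alpha^T a == beta^T b;
    # since a and b sum to the same total mass, the dual objective
    # a^T alpha + b^T beta is unchanged
    c = (b.dot(beta0) - a.dot(alpha0)) / (a.sum() + b.sum())
    return alpha0 + c, beta0 - c

def fix_null_weight_duals(alpha0, beta0, a, b, M):
    # violations of the dual constraint alpha_i + beta_j <= M_ij
    viol = alpha0[:, None] + beta0[None, :] - M
    # worst violation per row and per column
    aviol = viol.max(axis=1)
    bviol = viol.max(axis=0)
    # lower the potentials only on 0-weight entries: those entries carry
    # no mass, so the objective is untouched, but feasibility on the
    # corresponding rows/columns is restored
    alpha = alpha0 - (a == 0) * np.maximum(aviol, 0)
    beta = beta0 - (b == 0) * np.maximum(bviol, 0)
    return center_duals(alpha, beta, a, b)

rng = np.random.RandomState(0)
a = np.array([0.5, 0.5, 0.0])   # last source point has zero mass
b = np.array([0.25, 0.75])
M = rng.rand(3, 2)
alpha0, beta0 = rng.randn(3), rng.randn(2)

alpha, beta = fix_null_weight_duals(alpha0, beta0, a, b, M)
viol = alpha[:, None] + beta[None, :] - M
assert np.isclose(alpha.dot(a), beta.dot(b))        # centering constraint holds
assert np.all(viol[a == 0, :] <= 1e-12)             # 0-weight row is now feasible
assert np.isclose(a.dot(alpha) + b.dot(beta),
                  a.dot(alpha0) + b.dot(beta0))     # dual objective unchanged

Note the design choice mirrored from the patch: centering shifts every alpha_i by +c
and every beta_j by -c, so pairwise sums alpha_i + beta_j (and hence feasibility)
are preserved, which is why it can safely run after the violation repair.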