diff options
Diffstat (limited to 'ot/lp')
-rw-r--r-- | ot/lp/EMD.h | 2 | ||||
-rw-r--r-- | ot/lp/EMD_wrapper.cpp | 9 | ||||
-rw-r--r-- | ot/lp/__init__.py | 247 | ||||
-rw-r--r-- | ot/lp/emd_wrap.pyx | 27 | ||||
-rw-r--r-- | ot/lp/network_simplex_simple.h | 7 |
5 files changed, 250 insertions, 42 deletions
diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h index f42e222..c0fe7a3 100644 --- a/ot/lp/EMD.h +++ b/ot/lp/EMD.h @@ -32,4 +32,6 @@ enum ProblemType { int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter); + + #endif diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index fc7ca63..bc873ed 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -17,13 +17,13 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) { -// beware M and C anre strored in row major C style!!! - int n, m, i, cur; + // beware M and C anre strored in row major C style!!! + int n, m, i, cur; typedef FullBipartiteDigraph Digraph; - DIGRAPH_TYPEDEFS(FullBipartiteDigraph); + DIGRAPH_TYPEDEFS(FullBipartiteDigraph); - // Get the number of non zero coordinates for r and c + // Get the number of non zero coordinates for r and c n=0; for (int i=0; i<n1; i++) { double val=*(X+i); @@ -105,3 +105,4 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, return ret; } + diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 0c92810..514a607 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -2,8 +2,6 @@ """ Solvers for the original linear program OT problem - - """ # Author: Remi Flamary <remi.flamary@unice.fr> @@ -12,22 +10,169 @@ Solvers for the original linear program OT problem import multiprocessing import sys + import numpy as np from scipy.sparse import coo_matrix -from .import cvx - +from . import cvx +from .cvx import barycenter # import compiled emd from .emd_wrap import emd_c, check_result, emd_1d_sorted -from ..utils import parmap -from .cvx import barycenter from ..utils import dist +from ..utils import parmap + +__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', + 'emd_1d', 'emd2_1d', 'wasserstein_1d'] + + +def center_ot_dual(alpha0, beta0, a=None, b=None): + r"""Center dual OT potentials w.r.t. theirs weights + + The main idea of this function is to find unique dual potentials + that ensure some kind of centering/fairness. The main idea is to find dual potentials that lead to the same final objective value for both source and targets (see below for more details). It will help having + stability when multiple calling of the OT solver with small changes. + + Basically we add another constraint to the potential that will not + change the objective value but will ensure unicity. The constraint + is the following: + + .. math:: + \alpha^T a= \beta^T b + + in addition to the OT problem constraints. + + since :math:`\sum_i a_i=\sum_j b_j` this can be solved by adding/removing + a constant from both :math:`\alpha_0` and :math:`\beta_0`. + + .. math:: + c=\frac{\beta0^T b-\alpha_0^T a}{1^Tb+1^Ta} + + \alpha=\alpha_0+c + + \beta=\beta0+c + + Parameters + ---------- + alpha0 : (ns,) numpy.ndarray, float64 + Source dual potential + beta0 : (nt,) numpy.ndarray, float64 + Target dual potential + a : (ns,) numpy.ndarray, float64 + Source histogram (uniform weight if empty list) + b : (nt,) numpy.ndarray, float64 + Target histogram (uniform weight if empty list) + + Returns + ------- + alpha : (ns,) numpy.ndarray, float64 + Source centered dual potential + beta : (nt,) numpy.ndarray, float64 + Target centered dual potential + + """ + # if no weights are provided, use uniform + if a is None: + a = np.ones(alpha0.shape[0]) / alpha0.shape[0] + if b is None: + b = np.ones(beta0.shape[0]) / beta0.shape[0] + + # compute constant that balances the weighted sums of the duals + c = (b.dot(beta0) - a.dot(alpha0)) / (a.sum() + b.sum()) + + # update duals + alpha = alpha0 + c + beta = beta0 - c + + return alpha, beta + + +def estimate_dual_null_weights(alpha0, beta0, a, b, M): + r"""Estimate feasible values for 0-weighted dual potentials + + The feasible values are computed efficiently but rather coarsely. + + .. warning:: + This function is necessary because the C++ solver in emd_c + discards all samples in the distributions with + zeros weights. This means that while the primal variable (transport + matrix) is exact, the solver only returns feasible dual potentials + on the samples with weights different from zero. + + First we compute the constraints violations: + + .. math:: + V=\alpha+\beta^T-M + + Next we compute the max amount of violation per row (alpha) and + columns (beta) -__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', - 'emd_1d', 'emd2_1d', 'wasserstein_1d'] + .. math:: + v^a_i=\max_j V_{i,j} + + v^b_j=\max_i V_{i,j} + + Finally we update the dual potential with 0 weights if a + constraint is violated + + .. math:: + \alpha_i = \alpha_i -v^a_i \quad \text{ if } a_i=0 \text{ and } v^a_i>0 + + \beta_j = \beta_j -v^b_j \quad \text{ if } b_j=0 \text{ and } v^b_j>0 + + In the end the dual potentials are centered using function + :ref:`center_ot_dual`. + + Note that all those updates do not change the objective value of the + solution but provide dual potentials that do not violate the constraints. + + Parameters + ---------- + alpha0 : (ns,) numpy.ndarray, float64 + Source dual potential + beta0 : (nt,) numpy.ndarray, float64 + Target dual potential + alpha0 : (ns,) numpy.ndarray, float64 + Source dual potential + beta0 : (nt,) numpy.ndarray, float64 + Target dual potential + a : (ns,) numpy.ndarray, float64 + Source distribution (uniform weights if empty list) + b : (nt,) numpy.ndarray, float64 + Target distribution (uniform weights if empty list) + M : (ns,nt) numpy.ndarray, float64 + Loss matrix (c-order array with type float64) + + Returns + ------- + alpha : (ns,) numpy.ndarray, float64 + Source corrected dual potential + beta : (nt,) numpy.ndarray, float64 + Target corrected dual potential + + """ + + # binary indexing of non-zeros weights + asel = a != 0 + bsel = b != 0 + + # compute dual constraints violation + constraint_violation = alpha0[:, None] + beta0[None, :] - M + + # Compute largest violation per line and columns + aviol = np.max(constraint_violation, 1) + bviol = np.max(constraint_violation, 0) + # update corrects violation of + alpha_up = -1 * ~asel * np.maximum(aviol, 0) + beta_up = -1 * ~bsel * np.maximum(bviol, 0) -def emd(a, b, M, numItermax=100000, log=False): + alpha = alpha0 + alpha_up + beta = beta0 + beta_up + + return center_ot_dual(alpha, beta, a, b) + + +def emd(a, b, M, numItermax=100000, log=False, center_dual=True): r"""Solves the Earth Movers distance problem and returns the OT matrix @@ -35,7 +180,9 @@ def emd(a, b, M, numItermax=100000, log=False): \gamma = arg\min_\gamma <\gamma,M>_F s.t. \gamma 1 = a + \gamma^T 1= b + \gamma\geq 0 where : @@ -43,7 +190,7 @@ def emd(a, b, M, numItermax=100000, log=False): - a and b are the sample weights .. warning:: - Note that the M matrix needs to be a C-order numpy.array in float64 + Note that the M matrix needs to be a C-order numpy.array in float64 format. Uses the algorithm proposed in [1]_ @@ -62,6 +209,9 @@ def emd(a, b, M, numItermax=100000, log=False): log: bool, optional (default=False) If True, returns a dictionary containing the cost and dual variables. Otherwise returns only the optimal transportation matrix. + center_dual: boolean, optional (default=True) + If True, centers the dual potential using function + :ref:`center_ot_dual`. Returns ------- @@ -109,7 +259,20 @@ def emd(a, b, M, numItermax=100000, log=False): if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ + "Dimension mismatch, check dimensions of M with a and b" + + asel = a != 0 + bsel = b != 0 + G, cost, u, v, result_code = emd_c(a, b, M, numItermax) + + if center_dual: + u, v = center_ot_dual(u, v, a, b) + + if np.any(~asel) or np.any(~bsel): + u, v = estimate_dual_null_weights(u, v, a, b, M) + result_code_string = check_result(result_code) if log: log = {} @@ -123,14 +286,17 @@ def emd(a, b, M, numItermax=100000, log=False): def emd2(a, b, M, processes=multiprocessing.cpu_count(), - numItermax=100000, log=False, return_matrix=False): + numItermax=100000, log=False, return_matrix=False, + center_dual=True): r"""Solves the Earth Movers distance problem and returns the loss .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + \min_\gamma <\gamma,M>_F s.t. \gamma 1 = a + \gamma^T 1= b + \gamma\geq 0 where : @@ -138,7 +304,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), - a and b are the sample weights .. warning:: - Note that the M matrix needs to be a C-order numpy.array in float64 + Note that the M matrix needs to be a C-order numpy.array in float64 format. Uses the algorithm proposed in [1]_ @@ -161,6 +327,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), variables. Otherwise returns only the optimal transportation cost. return_matrix: boolean, optional (default=False) If True, returns the optimal transportation matrix in the log. + center_dual: boolean, optional (default=True) + If True, centers the dual potential using function + :ref:`center_ot_dual`. Returns ------- @@ -204,7 +373,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), # problem with pikling Forks if sys.platform.endswith('win32'): - processes=1 + processes = 1 # if empty array given then use uniform distributions if len(a) == 0: @@ -212,21 +381,43 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ + "Dimension mismatch, check dimensions of M with a and b" + + asel = a != 0 + if log or return_matrix: def f(b): - G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) - result_code_string = check_result(resultCode) + bsel = b != 0 + + G, cost, u, v, result_code = emd_c(a, b, M, numItermax) + + if center_dual: + u, v = center_ot_dual(u, v, a, b) + + if np.any(~asel) or np.any(~bsel): + u, v = estimate_dual_null_weights(u, v, a, b, M) + + result_code_string = check_result(result_code) log = {} if return_matrix: log['G'] = G log['u'] = u log['v'] = v log['warning'] = result_code_string - log['result_code'] = resultCode + log['result_code'] = result_code return [cost, log] else: def f(b): + bsel = b != 0 G, cost, u, v, result_code = emd_c(a, b, M, numItermax) + + if center_dual: + u, v = center_ot_dual(u, v, a, b) + + if np.any(~asel) or np.any(~bsel): + u, v = estimate_dual_null_weights(u, v, a, b, M) + check_result(result_code) return cost @@ -234,7 +425,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), return f(b) nb = b.shape[1] - if processes>1: + if processes > 1: res = parmap(f, [b[:, i] for i in range(nb)], processes) else: res = list(map(f, [b[:, i].copy() for i in range(nb)])) @@ -242,8 +433,8 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), return res - -def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None): +def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, + stopThr=1e-7, verbose=False, log=None): """ Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance) @@ -295,7 +486,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None k = X_init.shape[0] d = X_init.shape[1] if b is None: - b = np.ones((k,))/k + b = np.ones((k,)) / k if weights is None: weights = np.ones((N,)) / N @@ -306,17 +497,17 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None displacement_square_norm = stopThr + 1. - while ( displacement_square_norm > stopThr and iter_count < numItermax ): + while (displacement_square_norm > stopThr and iter_count < numItermax): T_sum = np.zeros((k, d)) - for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()): - + for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, + weights.tolist()): M_i = dist(X, measure_locations_i) T_i = emd(b, measure_weights_i, M_i) T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i) - displacement_square_norm = np.sum(np.square(T_sum-X)) + displacement_square_norm = np.sum(np.square(T_sum - X)) if log: displacement_square_norms.append(displacement_square_norm) @@ -436,12 +627,12 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, if b.ndim == 0 or len(b) == 0: b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0] - x_a_1d = x_a.reshape((-1, )) - x_b_1d = x_b.reshape((-1, )) + x_a_1d = x_a.reshape((-1,)) + x_b_1d = x_b.reshape((-1,)) perm_a = np.argsort(x_a_1d) perm_b = np.argsort(x_b_1d) - G_sorted, indices, cost = emd_1d_sorted(a, b, + G_sorted, indices, cost = emd_1d_sorted(a[perm_a], b[perm_b], x_a_1d[perm_a], x_b_1d[perm_b], metric=metric, p=p) G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])), diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 2b6c495..c167964 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -19,7 +19,7 @@ import warnings cdef extern from "EMD.h": - int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) + int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) nogil cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED @@ -35,8 +35,7 @@ def check_result(result_code): message = "numItermax reached before optimality. Try to increase numItermax." warnings.warn(message) return message - - + @cython.boundscheck(False) @cython.wraparound(False) def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter): @@ -61,6 +60,12 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod .. warning:: Note that the M matrix needs to be a C-order :py.cls:`numpy.array` + .. warning:: + The C++ solver discards all samples in the distributions with + zeros weights. This means that while the primal variable (transport + matrix) is exact, the solver only returns feasible dual potentials + on the samples with weights different from zero. + Parameters ---------- a : (ns,) numpy.ndarray, float64 @@ -73,7 +78,6 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod The maximum number of iterations before stopping the optimization algorithm if it has not converged. - Returns ------- gamma: (ns x nt) numpy.ndarray @@ -82,12 +86,19 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod """ cdef int n1= M.shape[0] cdef int n2= M.shape[1] + cdef int nmax=n1+n2-1 + cdef int result_code = 0 + cdef int nG=0 cdef double cost=0 - cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2]) cdef np.ndarray[double, ndim=1, mode="c"] alpha=np.zeros(n1) cdef np.ndarray[double, ndim=1, mode="c"] beta=np.zeros(n2) + cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([0, 0]) + + cdef np.ndarray[double, ndim=1, mode="c"] Gv=np.zeros(0) + cdef np.ndarray[long, ndim=1, mode="c"] iG=np.zeros(0,dtype=np.int) + cdef np.ndarray[long, ndim=1, mode="c"] jG=np.zeros(0,dtype=np.int) if not len(a): a=np.ones((n1,))/n1 @@ -95,8 +106,12 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod if not len(b): b=np.ones((n2,))/n2 + # init OT matrix + G=np.zeros([n1, n2]) + # calling the function - cdef int result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter) + with nogil: + result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter) return G, cost, alpha, beta, result_code diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h index 7c6a4ce..5d93040 100644 --- a/ot/lp/network_simplex_simple.h +++ b/ot/lp/network_simplex_simple.h @@ -686,7 +686,7 @@ namespace lemon { /// \see resetParams(), reset() ProblemType run() { #if DEBUG_LVL>0 - std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "\nUNBOUNDED = " << UNBOUNDED << "\nMAX_ITER_REACHED" << MAX_ITER_REACHED\n"; + std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "\nUNBOUNDED = " << UNBOUNDED << "\nMAX_ITER_REACHED" << MAX_ITER_REACHED << "\n" ; #endif if (!init()) return INFEASIBLE; @@ -875,7 +875,7 @@ namespace lemon { c += Number(it->second) * Number(_cost[it->first]); return c;*/ - for (int i=0; i<_flow.size(); i++) + for (unsigned long i=0; i<_flow.size(); i++) c += _flow[i] * Number(_cost[i]); return c; @@ -1257,7 +1257,7 @@ namespace lemon { u = w; } _pred[u_in] = in_arc; - _forward[u_in] = (u_in == _source[in_arc]); + _forward[u_in] = ((unsigned int)u_in == _source[in_arc]); _succ_num[u_in] = old_succ_num; // Set limits for updating _last_succ form v_in and v_out @@ -1418,7 +1418,6 @@ namespace lemon { template <typename PivotRuleImpl> ProblemType start() { PivotRuleImpl pivot(*this); - double prevCost=-1; ProblemType retVal = OPTIMAL; // Perform heuristic initial pivots |