# Provenance: added to ot/bregman.py by commit
# 818c7ace20da36d8042b0d7ad7a712b27f7afd59, "[MRG] Free support Sinkhorn
# barycenters (#387)", Eduardo Fernandes Montesuma, 2022-07-27
# (co-authored by Rémi Flamary). In that file the function sits between
# `barycenter_sinkhorn` and `_barycenter_sinkhorn_log`.
# NOTE(review): `get_backend`, `dist` and `sinkhorn` are not defined in this
# excerpt; presumably they are already in scope in ot/bregman.py -- confirm.
def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg, b=None, weights=None,
                                     numItermax=100, numInnerItermax=1000, stopThr=1e-7, verbose=False, log=None,
                                     **kwargs):
    r"""
    Solves the free support (locations of the barycenters are optimized, not the weights) regularized
    Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Sinkhorn divergence), formally:

    .. math::
        \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_{reg}^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i)

    where :

    - :math:`w \in (0, 1)^{N}`'s are the barycenter weights and sum to one
    - `measures_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights
      (on simplex)
    - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms
      locations
    - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter

    This problem is considered in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).
    There are three differences with the following codes:

    - we do not optimize over the weights
    - we do not do line search for the locations updates, i.e. we use :math:`\theta = 1` in
      :ref:`[20] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete
      implementation of the fixed-point algorithm of
      :ref:`[43] <references-free-support-barycenter>` proposed in the continuous setting.
    - at each iteration, instead of solving an exact OT problem, we use the Sinkhorn algorithm for calculating
      the transport plan in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).

    Parameters
    ----------
    measures_locations : list of N (k_i,d) array-like
        The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space
        (:math:`k_i` can be different for each element of the list)
    measures_weights : list of N (k_i,) array-like
        Arrays where each array has :math:`k_i` non-negatives values summing to one
        representing the weights of each discrete input measure
    X_init : (k,d) array-like
        Initialization of the support locations (on `k` atoms) of the barycenter
    reg : float
        Regularization term >0
    b : (k,) array-like, optional
        Weights of the barycenter atoms (non-negatives, sum to 1). They are kept fixed during the
        optimization. Uniform weights are used if None.
    weights : (N,) array-like, optional
        Coefficients :math:`w_i` of the input measures in the barycenter (non-negatives, sum to 1).
        Uniform weights are used if None.
    numItermax : int, optional
        Max number of iterations
    numInnerItermax : int, optional
        Max number of iterations when calculating the transport plans with Sinkhorn
    stopThr : float, optional
        Stop threshold (>0) on the squared norm of the displacement of the support between two
        consecutive iterations
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    **kwargs
        Extra keyword arguments forwarded unchanged to every call to `sinkhorn`

    Returns
    -------
    X : (k,d) array-like
        Support locations (on k atoms) of the barycenter
    log : dict
        Only returned if `log` is true. Its ``'displacement_square_norms'`` entry lists the squared
        displacement norm computed at each iteration.

    See Also
    --------
    ot.bregman.sinkhorn : Entropic regularized OT solver
    ot.lp.free_support_barycenter : Barycenter solver based on Linear Programming

    .. _references-free-support-barycenter:
    References
    ----------
    .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International
        Conference on Machine Learning. 2014.

    .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space."
        Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.

    """
    nx = get_backend(*measures_locations, *measures_weights, X_init)

    iter_count = 0

    N = len(measures_locations)
    k = X_init.shape[0]
    d = X_init.shape[1]
    if b is None:
        b = nx.ones((k,), type_as=X_init) / k
    if weights is None:
        weights = nx.ones((N,), type_as=X_init) / N

    X = X_init

    displacement_square_norms = []

    # Start strictly above the threshold so the first iteration is always
    # attempted (provided numItermax > 0).
    displacement_square_norm = stopThr + 1.

    while displacement_square_norm > stopThr and iter_count < numItermax:

        T_sum = nx.zeros((k, d), type_as=X_init)

        for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights):
            # Cost matrix between the current support and the atoms of measure i.
            M_i = dist(X, measure_locations_i)
            T_i = sinkhorn(b, measure_weights_i, M_i, reg=reg, numItermax=numInnerItermax, **kwargs)
            # Accumulate the w_i-weighted image of measure i's atoms under the
            # plan, each row being rescaled by the mass 1 / b of its atom.
            T_sum = T_sum + weight_i * 1. / b[:, None] * nx.dot(T_i, measure_locations_i)

        displacement_square_norm = nx.sum((T_sum - X) ** 2)
        if log:
            displacement_square_norms.append(displacement_square_norm)

        # Full step towards the new locations (no line search).
        X = T_sum

        if verbose:
            # The values must be interpolated with `%`; passing them as extra
            # arguments to print() would emit the raw format string instead.
            print('iteration %d, displacement_square_norm=%f\n' % (iter_count, displacement_square_norm))

        iter_count += 1

    if log:
        log_dict = {'displacement_square_norms': displacement_square_norms}
        return X, log_dict
    else:
        return X