summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorEduardo Fernandes Montesuma <edumontesuma@gmail.com>2022-07-27 11:16:14 +0200
committerGitHub <noreply@github.com>2022-07-27 11:16:14 +0200
commit818c7ace20da36d8042b0d7ad7a712b27f7afd59 (patch)
tree58dd4e0c9f990ea0c851712d85748de99ce6b236 /ot
parent7c2a9523747c90aebfef711fdf34b5bbdb6f2f4d (diff)
[MRG] Free support Sinkhorn barycenters (#387)
* Adding function for computing Sinkhorn Free Support barycenters * Adding exampel on Free Support Sinkhorn Barycenter * Fixing typo on free support sinkhorn barycenter example * Adding info on new Free Support Barycenter solver * Removing extra line so that code follows pep8 * Fixing issues with pep8 in example * Correcting issues with pep8 standards * Adding tests for free support sinkhorn barycenter * Adding section on Sinkhorn barycenter to the example * Changing distributions for the Sinkhorn barycenter example * Removing file that should not be on the last commit * Adding PR number to REALEASES.md * Adding new contributors * Update CONTRIBUTORS.md Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot')
-rw-r--r--ot/bregman.py120
1 files changed, 120 insertions, 0 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 34dcadb..b1321a4 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -1540,6 +1540,126 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
return geometricBar(weights, UKv)
+def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg, b=None, weights=None,
+ numItermax=100, numInnerItermax=1000, stopThr=1e-7, verbose=False, log=None,
+ **kwargs):
+ r"""
+ Solves the free support (locations of the barycenters are optimized, not the weights) regularized Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Sinkhorn divergence), formally:
+
+ .. math::
+ \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_{reg}^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i)
+
+ where :
+
+ - :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one
+ - `measure_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex)
+ - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations
+ - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter
+
+ This problem is considered in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).
+ There are two differences with the following codes:
+
+ - we do not optimize over the weights
+ - we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in
+ :ref:`[20] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete
+ implementation of the fixed-point algorithm of
+ :ref:`[43] <references-free-support-barycenter>` proposed in the continuous setting.
+ - at each iteration, instead of solving an exact OT problem, we use the Sinkhorn algorithm for calculating the
+ transport plan in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).
+
+ Parameters
+ ----------
+ measures_locations : list of N (k_i,d) array-like
+ The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space
+ (:math:`k_i` can be different for each element of the list)
+ measures_weights : list of N (k_i,) array-like
+ Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one
+ representing the weights of each discrete input measure
+
+ X_init : (k,d) array-like
+ Initialization of the support locations (on `k` atoms) of the barycenter
+ reg : float
+ Regularization term >0
+ b : (k,) array-like
+ Initialization of the weights of the barycenter (non-negatives, sum to 1)
+ weights : (N,) array-like
+ Initialization of the coefficients of the barycenter (non-negatives, sum to 1)
+
+ numItermax : int, optional
+ Max number of iterations
+ numInnerItermax : int, optional
+ Max number of iterations when calculating the transport plans with Sinkhorn
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ X : (k,d) array-like
+ Support locations (on k atoms) of the barycenter
+
+ See Also
+ --------
+ ot.bregman.sinkhorn : Entropic regularized OT solver
+ ot.lp.free_support_barycenter : Barycenter solver based on Linear Programming
+
+ .. _references-free-support-barycenter:
+ References
+ ----------
+ .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
+
+ .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
+
+ """
+ nx = get_backend(*measures_locations, *measures_weights, X_init)
+
+ iter_count = 0
+
+ N = len(measures_locations)
+ k = X_init.shape[0]
+ d = X_init.shape[1]
+ if b is None:
+ b = nx.ones((k,), type_as=X_init) / k
+ if weights is None:
+ weights = nx.ones((N,), type_as=X_init) / N
+
+ X = X_init
+
+ log_dict = {}
+ displacement_square_norms = []
+
+ displacement_square_norm = stopThr + 1.
+
+ while (displacement_square_norm > stopThr and iter_count < numItermax):
+
+ T_sum = nx.zeros((k, d), type_as=X_init)
+
+ for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights):
+ M_i = dist(X, measure_locations_i)
+ T_i = sinkhorn(b, measure_weights_i, M_i, reg=reg, numItermax=numInnerItermax, **kwargs)
+ T_sum = T_sum + weight_i * 1. / b[:, None] * nx.dot(T_i, measure_locations_i)
+
+ displacement_square_norm = nx.sum((T_sum - X) ** 2)
+ if log:
+ displacement_square_norms.append(displacement_square_norm)
+
+ X = T_sum
+
+ if verbose:
+ print('iteration %d, displacement_square_norm=%f\n', iter_count, displacement_square_norm)
+
+ iter_count += 1
+
+ if log:
+ log_dict['displacement_square_norms'] = displacement_square_norms
+ return X, log_dict
+ else:
+ return X
+
+
def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000,
stopThr=1e-4, verbose=False, log=False, warn=True):
r"""Compute the entropic wasserstein barycenter in log-domain