diff options
author | Gard Spreemann <gspr@nonempty.org> | 2023-06-14 16:51:31 +0200 |
---|---|---|
committer | Gard Spreemann <gspr@nonempty.org> | 2023-06-14 16:51:31 +0200 |
commit | 96788a3fe5601e4c3f49b592aa0d9c034247862e (patch) | |
tree | 5ee3ebcdea05f6766fc9858344913e40487e9067 /ot/lp/__init__.py | |
parent | 35bd2c98b642df78638d7d733bc1a89d873db1de (diff) | |
parent | 89f1613861152432807077fbb146578611dc5888 (diff) |
Merge tag '0.9.0' into dfsg/latestdfsg/latest
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r-- | ot/lp/__init__.py | 161 |
1 files changed, 139 insertions, 22 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 390c32d..2ff02ab 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -Solvers for the original linear program OT problem +Solvers for the original linear program OT problem. """ @@ -20,16 +20,17 @@ from .cvx import barycenter # import compiled emd from .emd_wrap import emd_c, check_result, emd_1d_sorted -from .solver_1d import emd_1d, emd2_1d, wasserstein_1d +from .solver_1d import (emd_1d, emd2_1d, wasserstein_1d, + binary_search_circle, wasserstein_circle, + semidiscrete_wasserstein2_unif_circle) from ..utils import dist, list_to_array from ..utils import parmap from ..backend import get_backend - - __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted', - 'emd_1d', 'emd2_1d', 'wasserstein_1d'] + 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter', + 'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle'] def check_number_threads(numThreads): @@ -232,6 +233,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): If this behaviour is unwanted, please make sure to provide a floating point input. + .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value. + Uses the algorithm proposed in :ref:`[1] <references-emd>`. Parameters @@ -391,6 +394,8 @@ def emd2(a, b, M, processes=1, If this behaviour is unwanted, please make sure to provide a floating point input. + .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value. + Uses the algorithm proposed in :ref:`[1] <references-emd2>`. Parameters @@ -483,6 +488,11 @@ def emd2(a, b, M, processes=1, assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ "Dimension mismatch, check dimensions of M with a and b" + # ensure that same mass + np.testing.assert_almost_equal(a.sum(0), + b.sum(0,keepdims=True), err_msg='a and b vector must have the same sum') + b = b * a.sum(0) / b.sum(0,keepdims=True) + asel = a != 0 numThreads = check_number_threads(numThreads) @@ -517,8 +527,8 @@ def emd2(a, b, M, processes=1, log['warning'] = result_code_string log['result_code'] = result_code cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as), - (a0, b0, M0), (log['u'] - nx.mean(log['u']), - log['v'] - nx.mean(log['v']), G)) + (a0, b0, M0), (log['u'] - nx.mean(log['u']), + log['v'] - nx.mean(log['v']), G)) return [cost, log] else: def f(b): @@ -572,18 +582,18 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None where : - :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one - - the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the empirical measures weights and sum to one for each :math:`i` - - the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}` are the empirical measures atoms locations + - `measure_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex) + - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter - This problem is considered in :ref:`[1] <references-free-support-barycenter>` (Algorithm 2). + This problem is considered in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2). There are two differences with the following codes: - we do not optimize over the weights - we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in - :ref:`[1] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete + :ref:`[20] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of - :ref:`[2] <references-free-support-barycenter>` proposed in the continuous setting. + :ref:`[43] <references-free-support-barycenter>` proposed in the continuous setting. Parameters ---------- @@ -623,13 +633,13 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None .. _references-free-support-barycenter: References ---------- - .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. - .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. + .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. """ - nx = get_backend(*measures_locations,*measures_weights,X_init) + nx = get_backend(*measures_locations, *measures_weights, X_init) iter_count = 0 @@ -637,9 +647,9 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None k = X_init.shape[0] d = X_init.shape[1] if b is None: - b = nx.ones((k,),type_as=X_init) / k + b = nx.ones((k,), type_as=X_init) / k if weights is None: - weights = nx.ones((N,),type_as=X_init) / N + weights = nx.ones((N,), type_as=X_init) / N X = X_init @@ -650,15 +660,14 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None while (displacement_square_norm > stopThr and iter_count < numItermax): - T_sum = nx.zeros((k, d),type_as=X_init) - + T_sum = nx.zeros((k, d), type_as=X_init) - for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights): + for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights): M_i = dist(X, measure_locations_i) T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads) - T_sum = T_sum + weight_i * 1. / b[:,None] * nx.dot(T_i, measure_locations_i) + T_sum = T_sum + weight_i * 1. / b[:, None] * nx.dot(T_i, measure_locations_i) - displacement_square_norm = nx.sum((T_sum - X)**2) + displacement_square_norm = nx.sum((T_sum - X) ** 2) if log: displacement_square_norms.append(displacement_square_norm) @@ -675,3 +684,111 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None else: return X + +def generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary, Y_init=None, b=None, weights=None, + numItermax=100, stopThr=1e-7, verbose=False, log=None, numThreads=1, eps=0): + r""" + Solves the free support generalised Wasserstein barycenter problem: finding a barycenter (a discrete measure with + a fixed amount of points of uniform weights) whose respective projections fit the input measures. + More formally: + + .. math:: + \min_\gamma \quad \sum_{i=1}^p w_i W_2^2(\nu_i, \mathbf{P}_i\#\gamma) + + where : + + - :math:`\gamma = \sum_{l=1}^n b_l\delta_{y_l}` is the desired barycenter with each :math:`y_l \in \mathbb{R}^d` + - :math:`\mathbf{b} \in \mathbb{R}^{n}` is the desired weights vector of the barycenter + - The input measures are :math:`\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{x_{i,j}}` + - The :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the respective empirical measures weights (on the simplex) + - The :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d_i}` are the respective empirical measures atoms locations + - :math:`w = (w_1, \cdots w_p)` are the barycenter coefficients (on the simplex) + - Each :math:`\mathbf{P}_i \in \mathbb{R}^{d, d_i}`, and :math:`P_i\#\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{P_ix_{i,j}}` + + As show by :ref:`[42] <references-generalized-free-support-barycenter>`, + this problem can be re-written as a Wasserstein Barycenter problem, + which we solve using the free support method :ref:`[20] <references-generalized-free-support-barycenter>` + (Algorithm 2). + + Parameters + ---------- + X_list : list of p (k_i,d_i) array-like + Discrete supports of the input measures: each consists of :math:`k_i` locations of a `d_i`-dimensional space + (:math:`k_i` can be different for each element of the list) + a_list : list of p (k_i,) array-like + Measure weights: each element is a vector (k_i) on the simplex + P_list : list of p (d_i,d) array-like + Each :math:`P_i` is a linear map :math:`\mathbb{R}^{d} \rightarrow \mathbb{R}^{d_i}` + n_samples_bary : int + Number of barycenter points + Y_init : (n_samples_bary,d) array-like + Initialization of the support locations (on `k` atoms) of the barycenter + b : (n_samples_bary,) array-like + Initialization of the weights of the barycenter measure (on the simplex) + weights : (p,) array-like + Initialization of the coefficients of the barycenter (on the simplex) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) + If compiled with OpenMP, chooses the number of threads to parallelize. + "max" selects the highest number possible. + eps: Stability coefficient for the change of variable matrix inversion + If the :math:`\mathbf{P}_i^T` matrices don't span :math:`\mathbb{R}^d`, the problem is ill-defined and a matrix + inversion will fail. In this case one may set eps=1e-8 and get a solution anyway (which may make little sense) + + + Returns + ------- + Y : (n_samples_bary,d) array-like + Support locations (on n_samples_bary atoms) of the barycenter + + + .. _references-generalized-free-support-barycenter: + References + ---------- + .. [20] Cuturi, M. and Doucet, A.. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + + .. [42] Delon, J., Gozlan, N., and Saint-Dizier, A.. Generalized Wasserstein barycenters between probability measures living on different subspaces. arXiv preprint arXiv:2105.09755, 2021. + + """ + nx = get_backend(*X_list, *a_list, *P_list) + d = P_list[0].shape[1] + p = len(X_list) + + if weights is None: + weights = nx.ones(p, type_as=X_list[0]) / p + + # variable change matrix to reduce the problem to a Wasserstein Barycenter (WB) + A = eps * nx.eye(d, type_as=X_list[0]) # if eps nonzero: will force the invertibility of A + for (P_i, lambda_i) in zip(P_list, weights): + A = A + lambda_i * P_i.T @ P_i + B = nx.inv(nx.sqrtm(A)) + + Z_list = [x @ Pi @ B.T for (x, Pi) in zip(X_list, P_list)] # change of variables -> (WB) problem on Z + + if Y_init is None: + Y_init = nx.randn(n_samples_bary, d, type_as=X_list[0]) + + if b is None: + b = nx.ones(n_samples_bary, type_as=X_list[0]) / n_samples_bary # not optimised + + out = free_support_barycenter(Z_list, a_list, Y_init, b, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, numThreads=numThreads) + + if log: # unpack + Y, log_dict = out + else: + Y = out + log_dict = None + Y = Y @ B.T # return to the Generalised WB formulation + + if log: + return Y, log_dict + else: + return Y |