From ccc076e0fc535b2c734214c0ac1936e9e2cbeb62 Mon Sep 17 00:00:00 2001 From: eloitanguy <69361683+eloitanguy@users.noreply.github.com> Date: Fri, 6 May 2022 08:43:21 +0200 Subject: [WIP] Generalized Wasserstein Barycenters (#372) * GWB first solver version * tests + example for gwb (untested) + free_bar doc fix * improved doc, fixed minor bugs, better example visu * minor doc + visu fixes * plot GWB pep8 fix * fixed partial gromov test reproductibility * added an animation for the GWB visu * added PR num * minor doc fixes + better gwb logo --- CONTRIBUTORS.md | 1 + README.md | 6 +- RELEASES.md | 6 + .../backends/plot_sliced_wass_grad_flow_pytorch.py | 6 +- .../plot_generalized_free_support_barycenter.py | 152 +++++++++++++++++++++ ot/__init__.py | 2 +- ot/lp/__init__.py | 145 +++++++++++++++++--- test/test_ot.py | 40 ++++++ test/test_partial.py | 1 + 9 files changed, 334 insertions(+), 25 deletions(-) create mode 100644 examples/barycenters/plot_generalized_free_support_barycenter.py diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index ab64fba..0909b14 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -37,6 +37,7 @@ The contributors to this library are: * [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance) * [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends) * [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning) +* [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters) ## Acknowledgments diff --git a/README.md b/README.md index e2b33d9..12340d5 100644 --- a/README.md +++ b/README.md @@ -288,4 +288,8 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., & Weed, J. (2019, April). [Statistical optimal transport via factored couplings](http://proceedings.mlr.press/v89/forrow19a/forrow19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2454-2465). PMLR. -[41] Chapel*, L., Flamary*, R., Wu, H., Févotte, C., Gasso, G. (2021). [Unbalanced Optimal Transport through Non-negative Penalized Linear Regression](https://proceedings.neurips.cc/paper/2021/file/c3c617a9b80b3ae1ebd868b0017cc349-Paper.pdf) Advances in Neural Information Processing Systems (NeurIPS), 2020. (Two first co-authors) \ No newline at end of file +[41] Chapel*, L., Flamary*, R., Wu, H., Févotte, C., Gasso, G. (2021). [Unbalanced Optimal Transport through Non-negative Penalized Linear Regression](https://proceedings.neurips.cc/paper/2021/file/c3c617a9b80b3ae1ebd868b0017cc349-Paper.pdf) Advances in Neural Information Processing Systems (NeurIPS), 2020. (Two first co-authors) + +[42] Delon, J., Gozlan, N., and Saint-Dizier, A. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021. + +[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. \ No newline at end of file diff --git a/RELEASES.md b/RELEASES.md index be2192e..3461832 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,5 +1,11 @@ # Releases +## 0.8.3dev + +#### New features + +- Added Generalized Wasserstein Barycenter solver + example (PR #372) + ## 0.8.2 diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py index cf5d64d..59e0042 100644 --- a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py +++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py @@ -103,7 +103,7 @@ ax = pl.axis() # %% # Animate trajectories of the gradient flow along iteration -# ------------------------------------------------------- +# --------------------------------------------------------- pl.figure(3, (8, 4)) @@ -122,7 +122,7 @@ ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100, # %% # Compute the Sliced Wasserstein Barycenter -# +# ----------------------------------------- x1_torch = torch.tensor(x1).to(device=device) x3_torch = torch.tensor(x3).to(device=device) xbinit = np.random.randn(500, 2) * 10 + 16 @@ -169,7 +169,7 @@ ax = pl.axis() # %% # Animate trajectories of the barycenter along gradient descent -# ------------------------------------------------------- +# ------------------------------------------------------------- pl.figure(5, (8, 4)) diff --git a/examples/barycenters/plot_generalized_free_support_barycenter.py b/examples/barycenters/plot_generalized_free_support_barycenter.py new file mode 100644 index 0000000..9af1953 --- /dev/null +++ b/examples/barycenters/plot_generalized_free_support_barycenter.py @@ -0,0 +1,152 @@ +# -*- coding: utf-8 -*- +""" +======================================= +Generalized Wasserstein Barycenter Demo +======================================= + +This example illustrates the computation of Generalized Wasserstein Barycenter +as proposed in [42]. + + +[42] Delon, J., Gozlan, N., and Saint-Dizier, A.. +Generalized Wasserstein barycenters between probability measures living on different subspaces. +arXiv preprint arXiv:2105.09755, 2021. + +""" + +# Author: Eloi Tanguy +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.pylab as pl +import ot +import matplotlib.animation as animation + +######################## +# Generate and plot data +# ---------------------- + +# Input measures +sub_sample_factor = 8 +I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2] +I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2] +I3 = pl.imread('../../data/heart.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2] + +sz = I1.shape[0] +UU, VV = np.meshgrid(np.arange(sz), np.arange(sz)) + +# Input measure locations in their respective 2D spaces +X_list = [np.stack((UU[im == 0], VV[im == 0]), 1) * 1.0 for im in [I1, I2, I3]] + +# Input measure weights +a_list = [ot.unif(x.shape[0]) for x in X_list] + +# Projections 3D -> 2D +P1 = np.array([[1, 0, 0], [0, 1, 0]]) +P2 = np.array([[0, 1, 0], [0, 0, 1]]) +P3 = np.array([[1, 0, 0], [0, 0, 1]]) +P_list = [P1, P2, P3] + +# Barycenter weights +weights = np.array([1 / 3, 1 / 3, 1 / 3]) + +# Number of barycenter points to compute +n_samples_bary = 150 + +# Send the input measures into 3D space for visualisation +X_visu = [Xi @ Pi for (Xi, Pi) in zip(X_list, P_list)] + +# Plot the input data +fig = plt.figure(figsize=(3, 3)) +axis = fig.add_subplot(1, 1, 1, projection="3d") +for Xi in X_visu: + axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) +axis.view_init(azim=45) +axis.set_xticks([]) +axis.set_yticks([]) +axis.set_zticks([]) +plt.show() + +################################# +# Barycenter computation and plot +# ------------------------------- + +Y = ot.lp.generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary) +fig = plt.figure(figsize=(3, 3)) + +axis = fig.add_subplot(1, 1, 1, projection="3d") +for Xi in X_visu: + axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) +axis.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) +axis.view_init(azim=45) +axis.set_xticks([]) +axis.set_yticks([]) +axis.set_zticks([]) +plt.show() + + +############################# +# Plotting projection matches +# --------------------------- + +fig = plt.figure(figsize=(9, 3)) + +ax = fig.add_subplot(1, 3, 1, projection='3d') +for Xi in X_visu: + ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) +ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) +ax.view_init(elev=0, azim=0) +ax.set_xticks([]) +ax.set_yticks([]) +ax.set_zticks([]) + +ax = fig.add_subplot(1, 3, 2, projection='3d') +for Xi in X_visu: + ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) +ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) +ax.view_init(elev=0, azim=90) +ax.set_xticks([]) +ax.set_yticks([]) +ax.set_zticks([]) + +ax = fig.add_subplot(1, 3, 3, projection='3d') +for Xi in X_visu: + ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) +ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) +ax.view_init(elev=90, azim=0) +ax.set_xticks([]) +ax.set_yticks([]) +ax.set_zticks([]) + +plt.tight_layout() +plt.show() + +############################################## +# Rotation animation +# -------------------------------------------- + +fig = plt.figure(figsize=(7, 7)) +ax = fig.add_subplot(1, 1, 1, projection="3d") + + +def _init(): + for Xi in X_visu: + ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) + ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) + ax.view_init(elev=0, azim=0) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_zticks([]) + return fig, + + +def _update_plot(i): + ax.view_init(elev=i, azim=4 * i) + return fig, + + +ani = animation.FuncAnimation(fig, _update_plot, init_func=_init, frames=90, interval=50, blit=True, repeat_delay=2000) diff --git a/ot/__init__.py b/ot/__init__.py index 86ed94e..15d8351 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -51,7 +51,7 @@ from .factored import factored_optimal_transport # utils functions from .utils import dist, unif, tic, toc, toq -__version__ = "0.8.2" +__version__ = "0.8.3dev" __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 390c32d..572781d 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -26,10 +26,8 @@ from ..utils import dist, list_to_array from ..utils import parmap from ..backend import get_backend - - __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted', - 'emd_1d', 'emd2_1d', 'wasserstein_1d'] + 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter'] def check_number_threads(numThreads): @@ -517,8 +515,8 @@ def emd2(a, b, M, processes=1, log['warning'] = result_code_string log['result_code'] = result_code cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as), - (a0, b0, M0), (log['u'] - nx.mean(log['u']), - log['v'] - nx.mean(log['v']), G)) + (a0, b0, M0), (log['u'] - nx.mean(log['u']), + log['v'] - nx.mean(log['v']), G)) return [cost, log] else: def f(b): @@ -572,18 +570,18 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None where : - :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one - - the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the empirical measures weights and sum to one for each :math:`i` - - the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}` are the empirical measures atoms locations + - `measure_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex) + - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter - This problem is considered in :ref:`[1] ` (Algorithm 2). + This problem is considered in :ref:`[20] ` (Algorithm 2). There are two differences with the following codes: - we do not optimize over the weights - we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in - :ref:`[1] ` (Algorithm 2). This can be seen as a discrete + :ref:`[20] ` (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of - :ref:`[2] ` proposed in the continuous setting. + :ref:`[43] ` proposed in the continuous setting. Parameters ---------- @@ -623,13 +621,13 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None .. _references-free-support-barycenter: References ---------- - .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. - .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. + .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. """ - nx = get_backend(*measures_locations,*measures_weights,X_init) + nx = get_backend(*measures_locations, *measures_weights, X_init) iter_count = 0 @@ -637,9 +635,9 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None k = X_init.shape[0] d = X_init.shape[1] if b is None: - b = nx.ones((k,),type_as=X_init) / k + b = nx.ones((k,), type_as=X_init) / k if weights is None: - weights = nx.ones((N,),type_as=X_init) / N + weights = nx.ones((N,), type_as=X_init) / N X = X_init @@ -650,15 +648,14 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None while (displacement_square_norm > stopThr and iter_count < numItermax): - T_sum = nx.zeros((k, d),type_as=X_init) - + T_sum = nx.zeros((k, d), type_as=X_init) - for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights): + for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights): M_i = dist(X, measure_locations_i) T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads) - T_sum = T_sum + weight_i * 1. / b[:,None] * nx.dot(T_i, measure_locations_i) + T_sum = T_sum + weight_i * 1. / b[:, None] * nx.dot(T_i, measure_locations_i) - displacement_square_norm = nx.sum((T_sum - X)**2) + displacement_square_norm = nx.sum((T_sum - X) ** 2) if log: displacement_square_norms.append(displacement_square_norm) @@ -675,3 +672,111 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None else: return X + +def generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary, Y_init=None, b=None, weights=None, + numItermax=100, stopThr=1e-7, verbose=False, log=None, numThreads=1, eps=0): + r""" + Solves the free support generalised Wasserstein barycenter problem: finding a barycenter (a discrete measure with + a fixed amount of points of uniform weights) whose respective projections fit the input measures. + More formally: + + .. math:: + \min_\gamma \quad \sum_{i=1}^p w_i W_2^2(\nu_i, \mathbf{P}_i\#\gamma) + + where : + + - :math:`\gamma = \sum_{l=1}^n b_l\delta_{y_l}` is the desired barycenter with each :math:`y_l \in \mathbb{R}^d` + - :math:`\mathbf{b} \in \mathbb{R}^{n}` is the desired weights vector of the barycenter + - The input measures are :math:`\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{x_{i,j}}` + - The :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the respective empirical measures weights (on the simplex) + - The :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d_i}` are the respective empirical measures atoms locations + - :math:`w = (w_1, \cdots w_p)` are the barycenter coefficients (on the simplex) + - Each :math:`\mathbf{P}_i \in \mathbb{R}^{d, d_i}`, and :math:`P_i\#\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{P_ix_{i,j}}` + + As show by :ref:`[42] `, + this problem can be re-written as a Wasserstein Barycenter problem, + which we solve using the free support method :ref:`[20] ` + (Algorithm 2). + + Parameters + ---------- + X_list : list of p (k_i,d_i) array-like + Discrete supports of the input measures: each consists of :math:`k_i` locations of a `d_i`-dimensional space + (:math:`k_i` can be different for each element of the list) + a_list : list of p (k_i,) array-like + Measure weights: each element is a vector (k_i) on the simplex + P_list : list of p (d_i,d) array-like + Each :math:`P_i` is a linear map :math:`\mathbb{R}^{d} \rightarrow \mathbb{R}^{d_i}` + n_samples_bary : int + Number of barycenter points + Y_init : (n_samples_bary,d) array-like + Initialization of the support locations (on `k` atoms) of the barycenter + b : (n_samples_bary,) array-like + Initialization of the weights of the barycenter measure (on the simplex) + weights : (p,) array-like + Initialization of the coefficients of the barycenter (on the simplex) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) + If compiled with OpenMP, chooses the number of threads to parallelize. + "max" selects the highest number possible. + eps: Stability coefficient for the change of variable matrix inversion + If the :math:`\mathbf{P}_i^T` matrices don't span :math:`\mathbb{R}^d`, the problem is ill-defined and a matrix + inversion will fail. In this case one may set eps=1e-8 and get a solution anyway (which may make little sense) + + + Returns + ------- + Y : (n_samples_bary,d) array-like + Support locations (on n_samples_bary atoms) of the barycenter + + + .. _references-generalized-free-support-barycenter: + References + ---------- + .. [20] Cuturi, M. and Doucet, A.. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + + .. [42] Delon, J., Gozlan, N., and Saint-Dizier, A.. Generalized Wasserstein barycenters between probability measures living on different subspaces. arXiv preprint arXiv:2105.09755, 2021. + + """ + nx = get_backend(*X_list, *a_list, *P_list) + d = P_list[0].shape[1] + p = len(X_list) + + if weights is None: + weights = nx.ones(p, type_as=X_list[0]) / p + + # variable change matrix to reduce the problem to a Wasserstein Barycenter (WB) + A = eps * nx.eye(d, type_as=X_list[0]) # if eps nonzero: will force the invertibility of A + for (P_i, lambda_i) in zip(P_list, weights): + A = A + lambda_i * P_i.T @ P_i + B = nx.inv(nx.sqrtm(A)) + + Z_list = [x @ Pi @ B.T for (x, Pi) in zip(X_list, P_list)] # change of variables -> (WB) problem on Z + + if Y_init is None: + Y_init = nx.randn(n_samples_bary, d, type_as=X_list[0]) + + if b is None: + b = nx.ones(n_samples_bary, type_as=X_list[0]) / n_samples_bary # not optimised + + out = free_support_barycenter(Z_list, a_list, Y_init, b, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, numThreads=numThreads) + + if log: # unpack + Y, log_dict = out + else: + Y = out + log_dict = None + Y = Y @ B.T # return to the Generalised WB formulation + + if log: + return Y, log_dict + else: + return Y diff --git a/test/test_ot.py b/test/test_ot.py index bf832f6..ba3ef6a 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -320,6 +320,46 @@ def test_free_support_barycenter_backends(nx): np.testing.assert_allclose(X, nx.to_numpy(X2)) +def test_generalised_free_support_barycenter(): + np.random.seed(42) # random inits + X = [np.array([-1., -1.]).reshape((1, 2)), np.array([1., 1.]).reshape((1, 2))] # two 2D points bar is obviously 0 + a = [np.array([1.]), np.array([1.])] + + P = [np.eye(2), np.eye(2)] + + Y_init = np.array([-12., 7.]).reshape((1, 2)) + + # obvious barycenter location between two 2D diracs + Y_true = np.array([0., .0]).reshape((1, 2)) + + # test without log and no init + Y = ot.lp.generalized_free_support_barycenter(X, a, P, 1) + np.testing.assert_allclose(Y, Y_true, rtol=1e-5, atol=1e-7) + + # test with log and init + Y, _ = ot.lp.generalized_free_support_barycenter(X, a, P, 1, Y_init=Y_init, b=np.array([1.]), log=True) + np.testing.assert_allclose(Y, Y_true, rtol=1e-5, atol=1e-7) + + +def test_generalised_free_support_barycenter_backends(nx): + np.random.seed(42) + X = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))] + a = [np.array([1.]), np.array([1.])] + P = [np.array([1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))] + Y_init = np.array([-12.]).reshape((1, 1)) + + Y = ot.lp.generalized_free_support_barycenter(X, a, P, 1, Y_init=Y_init) + + X2 = nx.from_numpy(*X) + a2 = nx.from_numpy(*a) + P2 = nx.from_numpy(*P) + Y_init2 = nx.from_numpy(Y_init) + + Y2 = ot.lp.generalized_free_support_barycenter(X2, a2, P2, 1, Y_init=Y_init2) + + np.testing.assert_allclose(Y, nx.to_numpy(Y2)) + + @pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available") def test_lp_barycenter_cvxopt(): a1 = np.array([1.0, 0, 0])[:, None] diff --git a/test/test_partial.py b/test/test_partial.py index 97c611b..e07377b 100755 --- a/test/test_partial.py +++ b/test/test_partial.py @@ -137,6 +137,7 @@ def test_partial_wasserstein(): def test_partial_gromov_wasserstein(): + np.random.seed(42) n_samples = 20 # nb samples n_noise = 10 # nb of samples (noise) -- cgit v1.2.3