{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n# 1D Wasserstein barycenter comparison between exact LP and entropic regularization\n\n\nThis example illustrates the computation of regularized Wasserstein Barycenter\nas proposed in [3] and exact LP barycenters using standard LP solver.\n\nIt reproduces approximately Figure 3.1 and 3.2 from the following paper:\nCuturi, M., & Peyr\u00e9, G. (2016). A smoothed dual approach for variational\nWasserstein problems. SIAM Journal on Imaging Sciences, 9(1), 320-343.\n\n[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyr\u00e9, G. (2015).\nIterative Bregman projections for regularized transportation problems\nSIAM Journal on Scientific Computing, 37(2), A1111-A1138.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: Remi Flamary \n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\n# necessary for 3d plot even if not used\nfrom mpl_toolkits.mplot3d import Axes3D # noqa\nfrom matplotlib.collections import PolyCollection # noqa\n\n#import ot.lp.cvx as cvx" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Gaussian Data\n-------------\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "#%% parameters\n\nproblems = []\n\nn = 100 # nb bins\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\n# Gaussian distributions\n# Gaussian distributions\na1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std\na2 = ot.datasets.make_1D_gauss(n, m=60, s=8)\n\n# creating matrix A containing all distributions\nA = np.vstack((a1, a2)).T\nn_distributions = A.shape[1]\n\n# loss matrix + normalization\nM = ot.utils.dist0(n)\nM /= M.max()\n\n\n#%% plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\npl.tight_layout()\n\n#%% barycenter computation\n\nalpha = 0.5 # 0<=alpha<=1\nweights = np.array([1 - alpha, alpha])\n\n# l2bary\nbary_l2 = A.dot(weights)\n\n# wasserstein\nreg = 1e-3\not.tic()\nbary_wass = ot.bregman.barycenter(A, M, reg, weights)\not.toc()\n\n\not.tic()\nbary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)\not.toc()\n\npl.figure(2)\npl.clf()\npl.subplot(2, 1, 1)\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\n\npl.subplot(2, 1, 2)\npl.plot(x, bary_l2, 'r', label='l2')\npl.plot(x, bary_wass, 'g', label='Reg Wasserstein')\npl.plot(x, bary_wass2, 'b', label='LP Wasserstein')\npl.legend()\npl.title('Barycenters')\npl.tight_layout()\n\nproblems.append([A, [bary_l2, bary_wass, bary_wass2]])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Dirac Data\n----------\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "#%% parameters\n\na1 = 1.0 * (x > 10) * (x < 50)\na2 = 1.0 * (x > 60) * (x < 80)\n\na1 /= a1.sum()\na2 /= a2.sum()\n\n# creating matrix A containing all distributions\nA = np.vstack((a1, a2)).T\nn_distributions = A.shape[1]\n\n# loss matrix + normalization\nM = ot.utils.dist0(n)\nM /= M.max()\n\n\n#%% plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\npl.tight_layout()\n\n\n#%% barycenter computation\n\nalpha = 0.5 # 0<=alpha<=1\nweights = np.array([1 - alpha, alpha])\n\n# l2bary\nbary_l2 = A.dot(weights)\n\n# wasserstein\nreg = 1e-3\not.tic()\nbary_wass = ot.bregman.barycenter(A, M, reg, weights)\not.toc()\n\n\not.tic()\nbary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)\not.toc()\n\n\nproblems.append([A, [bary_l2, bary_wass, bary_wass2]])\n\npl.figure(2)\npl.clf()\npl.subplot(2, 1, 1)\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\n\npl.subplot(2, 1, 2)\npl.plot(x, bary_l2, 'r', label='l2')\npl.plot(x, bary_wass, 'g', label='Reg Wasserstein')\npl.plot(x, bary_wass2, 'b', label='LP Wasserstein')\npl.legend()\npl.title('Barycenters')\npl.tight_layout()\n\n#%% parameters\n\na1 = np.zeros(n)\na2 = np.zeros(n)\n\na1[10] = .25\na1[20] = .5\na1[30] = .25\na2[80] = 1\n\n\na1 /= a1.sum()\na2 /= a2.sum()\n\n# creating matrix A containing all distributions\nA = np.vstack((a1, a2)).T\nn_distributions = A.shape[1]\n\n# loss matrix + normalization\nM = ot.utils.dist0(n)\nM /= M.max()\n\n\n#%% plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\npl.tight_layout()\n\n\n#%% barycenter computation\n\nalpha = 0.5 # 0<=alpha<=1\nweights = np.array([1 - alpha, alpha])\n\n# l2bary\nbary_l2 = A.dot(weights)\n\n# wasserstein\nreg = 1e-3\not.tic()\nbary_wass = ot.bregman.barycenter(A, M, reg, weights)\not.toc()\n\n\not.tic()\nbary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)\not.toc()\n\n\nproblems.append([A, [bary_l2, bary_wass, bary_wass2]])\n\npl.figure(2)\npl.clf()\npl.subplot(2, 1, 1)\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\n\npl.subplot(2, 1, 2)\npl.plot(x, bary_l2, 'r', label='l2')\npl.plot(x, bary_wass, 'g', label='Reg Wasserstein')\npl.plot(x, bary_wass2, 'b', label='LP Wasserstein')\npl.legend()\npl.title('Barycenters')\npl.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Final figure\n------------\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "#%% plot\n\nnbm = len(problems)\nnbm2 = (nbm // 2)\n\n\npl.figure(2, (20, 6))\npl.clf()\n\nfor i in range(nbm):\n\n A = problems[i][0]\n bary_l2 = problems[i][1][0]\n bary_wass = problems[i][1][1]\n bary_wass2 = problems[i][1][2]\n\n pl.subplot(2, nbm, 1 + i)\n for j in range(n_distributions):\n pl.plot(x, A[:, j])\n if i == nbm2:\n pl.title('Distributions')\n pl.xticks(())\n pl.yticks(())\n\n pl.subplot(2, nbm, 1 + i + nbm)\n\n pl.plot(x, bary_l2, 'r', label='L2 (Euclidean)')\n pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')\n pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')\n if i == nbm - 1:\n pl.legend()\n if i == nbm2:\n pl.title('Barycenters')\n\n pl.xticks(())\n pl.yticks(())" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 0 }