summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples/plot_barycenter_lp_vs_entropic.ipynb
blob: 21991623576e43d040349bd8d597a46a019b36b3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# 1D Wasserstein barycenter comparison between exact LP and entropic regularization\n\n\nThis example illustrates the computation of regularized Wasserstein Barycenter\nas proposed in [3] and exact LP barycenters using standard LP solver.\n\nIt reproduces approximately Figure 3.1 and 3.2 from the following paper:\nCuturi, M., & Peyr\u00e9, G. (2016). A smoothed dual approach for variational\nWasserstein problems. SIAM Journal on Imaging Sciences, 9(1), 320-343.\n\n[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyr\u00e9, G. (2015).\nIterative Bregman projections for regularized transportation problems\nSIAM Journal on Scientific Computing, 37(2), A1111-A1138.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Author: Remi Flamary <remi.flamary@unice.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\n# necessary for 3d plot even if not used\nfrom mpl_toolkits.mplot3d import Axes3D  # noqa\nfrom matplotlib.collections import PolyCollection  # noqa\n\n#import ot.lp.cvx as cvx"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Gaussian Data\n-------------\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "#%% parameters\n\nproblems = []\n\nn = 100  # nb bins\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\n# Gaussian distributions\n# Gaussian distributions\na1 = ot.datasets.make_1D_gauss(n, m=20, s=5)  # m= mean, s= std\na2 = ot.datasets.make_1D_gauss(n, m=60, s=8)\n\n# creating matrix A containing all distributions\nA = np.vstack((a1, a2)).T\nn_distributions = A.shape[1]\n\n# loss matrix + normalization\nM = ot.utils.dist0(n)\nM /= M.max()\n\n\n#%% plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\nfor i in range(n_distributions):\n    pl.plot(x, A[:, i])\npl.title('Distributions')\npl.tight_layout()\n\n#%% barycenter computation\n\nalpha = 0.5  # 0<=alpha<=1\nweights = np.array([1 - alpha, alpha])\n\n# l2bary\nbary_l2 = A.dot(weights)\n\n# wasserstein\nreg = 1e-3\not.tic()\nbary_wass = ot.bregman.barycenter(A, M, reg, weights)\not.toc()\n\n\not.tic()\nbary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)\not.toc()\n\npl.figure(2)\npl.clf()\npl.subplot(2, 1, 1)\nfor i in range(n_distributions):\n    pl.plot(x, A[:, i])\npl.title('Distributions')\n\npl.subplot(2, 1, 2)\npl.plot(x, bary_l2, 'r', label='l2')\npl.plot(x, bary_wass, 'g', label='Reg Wasserstein')\npl.plot(x, bary_wass2, 'b', label='LP Wasserstein')\npl.legend()\npl.title('Barycenters')\npl.tight_layout()\n\nproblems.append([A, [bary_l2, bary_wass, bary_wass2]])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Dirac Data\n----------\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "#%% parameters\n\na1 = 1.0 * (x > 10) * (x < 50)\na2 = 1.0 * (x > 60) * (x < 80)\n\na1 /= a1.sum()\na2 /= a2.sum()\n\n# creating matrix A containing all distributions\nA = np.vstack((a1, a2)).T\nn_distributions = A.shape[1]\n\n# loss matrix + normalization\nM = ot.utils.dist0(n)\nM /= M.max()\n\n\n#%% plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\nfor i in range(n_distributions):\n    pl.plot(x, A[:, i])\npl.title('Distributions')\npl.tight_layout()\n\n\n#%% barycenter computation\n\nalpha = 0.5  # 0<=alpha<=1\nweights = np.array([1 - alpha, alpha])\n\n# l2bary\nbary_l2 = A.dot(weights)\n\n# wasserstein\nreg = 1e-3\not.tic()\nbary_wass = ot.bregman.barycenter(A, M, reg, weights)\not.toc()\n\n\not.tic()\nbary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)\not.toc()\n\n\nproblems.append([A, [bary_l2, bary_wass, bary_wass2]])\n\npl.figure(2)\npl.clf()\npl.subplot(2, 1, 1)\nfor i in range(n_distributions):\n    pl.plot(x, A[:, i])\npl.title('Distributions')\n\npl.subplot(2, 1, 2)\npl.plot(x, bary_l2, 'r', label='l2')\npl.plot(x, bary_wass, 'g', label='Reg Wasserstein')\npl.plot(x, bary_wass2, 'b', label='LP Wasserstein')\npl.legend()\npl.title('Barycenters')\npl.tight_layout()\n\n#%% parameters\n\na1 = np.zeros(n)\na2 = np.zeros(n)\n\na1[10] = .25\na1[20] = .5\na1[30] = .25\na2[80] = 1\n\n\na1 /= a1.sum()\na2 /= a2.sum()\n\n# creating matrix A containing all distributions\nA = np.vstack((a1, a2)).T\nn_distributions = A.shape[1]\n\n# loss matrix + normalization\nM = ot.utils.dist0(n)\nM /= M.max()\n\n\n#%% plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\nfor i in range(n_distributions):\n    pl.plot(x, A[:, i])\npl.title('Distributions')\npl.tight_layout()\n\n\n#%% barycenter computation\n\nalpha = 0.5  # 0<=alpha<=1\nweights = np.array([1 - alpha, alpha])\n\n# l2bary\nbary_l2 = A.dot(weights)\n\n# wasserstein\nreg = 1e-3\not.tic()\nbary_wass = ot.bregman.barycenter(A, M, reg, weights)\not.toc()\n\n\not.tic()\nbary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)\not.toc()\n\n\nproblems.append([A, [bary_l2, bary_wass, bary_wass2]])\n\npl.figure(2)\npl.clf()\npl.subplot(2, 1, 1)\nfor i in range(n_distributions):\n    pl.plot(x, A[:, i])\npl.title('Distributions')\n\npl.subplot(2, 1, 2)\npl.plot(x, bary_l2, 'r', label='l2')\npl.plot(x, bary_wass, 'g', label='Reg Wasserstein')\npl.plot(x, bary_wass2, 'b', label='LP Wasserstein')\npl.legend()\npl.title('Barycenters')\npl.tight_layout()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Final figure\n------------\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "#%% plot\n\nnbm = len(problems)\nnbm2 = (nbm // 2)\n\n\npl.figure(2, (20, 6))\npl.clf()\n\nfor i in range(nbm):\n\n    A = problems[i][0]\n    bary_l2 = problems[i][1][0]\n    bary_wass = problems[i][1][1]\n    bary_wass2 = problems[i][1][2]\n\n    pl.subplot(2, nbm, 1 + i)\n    for j in range(n_distributions):\n        pl.plot(x, A[:, j])\n    if i == nbm2:\n        pl.title('Distributions')\n    pl.xticks(())\n    pl.yticks(())\n\n    pl.subplot(2, nbm, 1 + i + nbm)\n\n    pl.plot(x, bary_l2, 'r', label='L2 (Euclidean)')\n    pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')\n    pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')\n    if i == nbm - 1:\n        pl.legend()\n    if i == nbm2:\n        pl.title('Barycenters')\n\n    pl.xticks(())\n    pl.yticks(())"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.5"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}