From 5e7c6ab04be3dc2035ca2a7f9deab3bb3bfb8faa Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Tue, 25 Jun 2019 14:43:39 +0200 Subject: doc add examples unbalanced --- docs/cache_nbrun | 2 +- .../source/auto_examples/auto_examples_jupyter.zip | Bin 139016 -> 148147 bytes docs/source/auto_examples/auto_examples_python.zip | Bin 93470 -> 99229 bytes .../images/sphx_glr_plot_UOT_1D_001.png | Bin 0 -> 21239 bytes .../images/sphx_glr_plot_UOT_1D_002.png | Bin 0 -> 22051 bytes .../images/sphx_glr_plot_UOT_1D_006.png | Bin 0 -> 21288 bytes .../images/sphx_glr_plot_UOT_barycenter_1D_001.png | Bin 0 -> 22177 bytes .../images/sphx_glr_plot_UOT_barycenter_1D_003.png | Bin 0 -> 42539 bytes .../images/sphx_glr_plot_UOT_barycenter_1D_005.png | Bin 0 -> 105997 bytes .../images/sphx_glr_plot_UOT_barycenter_1D_006.png | Bin 0 -> 103234 bytes .../images/thumb/sphx_glr_plot_UOT_1D_thumb.png | Bin 0 -> 14761 bytes .../sphx_glr_plot_UOT_barycenter_1D_thumb.png | Bin 0 -> 15099 bytes docs/source/auto_examples/index.rst | 40 ++++ docs/source/auto_examples/plot_UOT_1D.ipynb | 108 +++++++++ docs/source/auto_examples/plot_UOT_1D.py | 76 ++++++ docs/source/auto_examples/plot_UOT_1D.rst | 173 ++++++++++++++ .../auto_examples/plot_UOT_barycenter_1D.ipynb | 126 ++++++++++ .../source/auto_examples/plot_UOT_barycenter_1D.py | 164 +++++++++++++ .../auto_examples/plot_UOT_barycenter_1D.rst | 261 +++++++++++++++++++++ 19 files changed, 949 insertions(+), 1 deletion(-) create mode 100644 docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_001.png create mode 100644 docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_002.png create mode 100644 docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_006.png create mode 100644 docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_001.png create mode 100644 docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_003.png create mode 100644 docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_005.png create mode 100644 docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_006.png create mode 100644 docs/source/auto_examples/images/thumb/sphx_glr_plot_UOT_1D_thumb.png create mode 100644 docs/source/auto_examples/images/thumb/sphx_glr_plot_UOT_barycenter_1D_thumb.png create mode 100644 docs/source/auto_examples/plot_UOT_1D.ipynb create mode 100644 docs/source/auto_examples/plot_UOT_1D.py create mode 100644 docs/source/auto_examples/plot_UOT_1D.rst create mode 100644 docs/source/auto_examples/plot_UOT_barycenter_1D.ipynb create mode 100644 docs/source/auto_examples/plot_UOT_barycenter_1D.py create mode 100644 docs/source/auto_examples/plot_UOT_barycenter_1D.rst (limited to 'docs') diff --git a/docs/cache_nbrun b/docs/cache_nbrun index 04f6fce..8a95023 100644 --- a/docs/cache_nbrun +++ b/docs/cache_nbrun @@ -1 +1 @@ -{"plot_otda_color_images.ipynb": "f804d5806c7ac1a0901e4542b1eaa77b", "plot_WDA.ipynb": "27f8de4c6d7db46497076523673eedfb", "plot_OT_L1_vs_L2.ipynb": "5d565b8aaf03be4309eba731127851dc", "plot_otda_semi_supervised.ipynb": "f6dfb02ba2bbd939408ffcd22a3b007c", "plot_fgw.ipynb": "2ba3e100e92ecf4dfbeb605de20b40ab", "plot_otda_d2.ipynb": "e6feae588103f2a8fab942e5f4eff483", "plot_compute_emd.ipynb": "f5cd71cad882ec157dc8222721e9820c", "plot_barycenter_fgw.ipynb": "e14100dd276bff3ffdfdf176f1b6b070", "plot_convolutional_barycenter.ipynb": "a72bb3716a1baaffd81ae267a673f9b6", "plot_optim_OTreg.ipynb": "481801bb0d133ef350a65179cf8f739a", "plot_barycenter_lp_vs_entropic.ipynb": "51833e8c76aaedeba9599ac7a30eb357", "plot_OT_1D_smooth.ipynb": "3a059103652225a0c78ea53895cf79e5", "plot_barycenter_1D.ipynb": "5f6fb8aebd8e2e91ebc77c923cb112b3", "plot_otda_mapping.ipynb": "2f1ebbdc0f855d9e2b7adf9edec24d25", "plot_OT_1D.ipynb": "b5348bdc561c07ec168a1622e5af4b93", "plot_gromov_barycenter.ipynb": "953e5047b886ec69ec621ec52f5e21d1", "plot_otda_mapping_colors_images.ipynb": "cc8bf9a857f52e4a159fe71dfda19018", "plot_stochastic.ipynb": "e18253354c8c1d72567a4259eb1094f7", "plot_otda_linear_mapping.ipynb": "a472c767abe82020e0a58125a528785c", "plot_otda_classes.ipynb": "39087b6e98217851575f2271c22853a4", "plot_free_support_barycenter.ipynb": "246dd2feff4b233a4f1a553c5a202fdc", "plot_gromov.ipynb": "24f2aea489714d34779521f46d5e2c47", "plot_OT_2D_samples.ipynb": "912a77c5dd0fc0fafa03fac3d86f1502"} \ No newline at end of file +{"plot_otda_semi_supervised.ipynb": "f6dfb02ba2bbd939408ffcd22a3b007c", "plot_WDA.ipynb": "27f8de4c6d7db46497076523673eedfb", "plot_UOT_1D.ipynb": "fc7dd383e625597bd59fff03a8430c91", "plot_OT_L1_vs_L2.ipynb": "5d565b8aaf03be4309eba731127851dc", "plot_otda_color_images.ipynb": "f804d5806c7ac1a0901e4542b1eaa77b", "plot_fgw.ipynb": "2ba3e100e92ecf4dfbeb605de20b40ab", "plot_otda_d2.ipynb": "e6feae588103f2a8fab942e5f4eff483", "plot_compute_emd.ipynb": "f5cd71cad882ec157dc8222721e9820c", "plot_barycenter_fgw.ipynb": "e14100dd276bff3ffdfdf176f1b6b070", "plot_convolutional_barycenter.ipynb": "a72bb3716a1baaffd81ae267a673f9b6", "plot_optim_OTreg.ipynb": "481801bb0d133ef350a65179cf8f739a", "plot_barycenter_lp_vs_entropic.ipynb": "51833e8c76aaedeba9599ac7a30eb357", "plot_OT_1D_smooth.ipynb": "3a059103652225a0c78ea53895cf79e5", "plot_barycenter_1D.ipynb": "5f6fb8aebd8e2e91ebc77c923cb112b3", "plot_otda_mapping.ipynb": "2f1ebbdc0f855d9e2b7adf9edec24d25", "plot_OT_1D.ipynb": "b5348bdc561c07ec168a1622e5af4b93", "plot_gromov_barycenter.ipynb": "953e5047b886ec69ec621ec52f5e21d1", "plot_UOT_barycenter_1D.ipynb": "c72f0bfb6e1a79710dad3fef9f5c557c", "plot_otda_mapping_colors_images.ipynb": "cc8bf9a857f52e4a159fe71dfda19018", "plot_stochastic.ipynb": "e18253354c8c1d72567a4259eb1094f7", "plot_otda_linear_mapping.ipynb": "a472c767abe82020e0a58125a528785c", "plot_otda_classes.ipynb": "39087b6e98217851575f2271c22853a4", "plot_free_support_barycenter.ipynb": "246dd2feff4b233a4f1a553c5a202fdc", "plot_gromov.ipynb": "24f2aea489714d34779521f46d5e2c47", "plot_OT_2D_samples.ipynb": "912a77c5dd0fc0fafa03fac3d86f1502"} \ No newline at end of file diff --git a/docs/source/auto_examples/auto_examples_jupyter.zip b/docs/source/auto_examples/auto_examples_jupyter.zip index a3a7c29..901195a 100644 Binary files a/docs/source/auto_examples/auto_examples_jupyter.zip and b/docs/source/auto_examples/auto_examples_jupyter.zip differ diff --git a/docs/source/auto_examples/auto_examples_python.zip b/docs/source/auto_examples/auto_examples_python.zip index 86a6841..ded2613 100644 Binary files a/docs/source/auto_examples/auto_examples_python.zip and b/docs/source/auto_examples/auto_examples_python.zip differ diff --git a/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_001.png b/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_001.png new file mode 100644 index 0000000..69ef5b7 Binary files /dev/null and b/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_001.png differ diff --git a/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_002.png b/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_002.png new file mode 100644 index 0000000..0407e44 Binary files /dev/null and b/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_002.png differ diff --git a/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_006.png b/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_006.png new file mode 100644 index 0000000..f58d383 Binary files /dev/null and b/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_006.png differ diff --git a/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_001.png b/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_001.png new file mode 100644 index 0000000..ec8c51e Binary files /dev/null and b/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_001.png differ diff --git a/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_003.png b/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_003.png new file mode 100644 index 0000000..89ab265 Binary files /dev/null and b/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_003.png differ diff --git a/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_005.png b/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_005.png new file mode 100644 index 0000000..c6c49cb Binary files /dev/null and b/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_005.png differ diff --git a/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_006.png b/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_006.png new file mode 100644 index 0000000..8870b10 Binary files /dev/null and b/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_006.png differ diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_UOT_1D_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_UOT_1D_thumb.png new file mode 100644 index 0000000..1d048f2 Binary files /dev/null and b/docs/source/auto_examples/images/thumb/sphx_glr_plot_UOT_1D_thumb.png differ diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_UOT_barycenter_1D_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_UOT_barycenter_1D_thumb.png new file mode 100644 index 0000000..999f175 Binary files /dev/null and b/docs/source/auto_examples/images/thumb/sphx_glr_plot_UOT_barycenter_1D_thumb.png differ diff --git a/docs/source/auto_examples/index.rst b/docs/source/auto_examples/index.rst index 9f02da4..fe6702d 100644 --- a/docs/source/auto_examples/index.rst +++ b/docs/source/auto_examples/index.rst @@ -27,6 +27,26 @@ This is a gallery of all the POT example files. /auto_examples/plot_OT_1D +.. raw:: html + +
+ +.. only:: html + + .. figure:: /auto_examples/images/thumb/sphx_glr_plot_UOT_1D_thumb.png + + :ref:`sphx_glr_auto_examples_plot_UOT_1D.py` + +.. raw:: html + +
+ + +.. toctree:: + :hidden: + + /auto_examples/plot_UOT_1D + .. raw:: html
@@ -287,6 +307,26 @@ This is a gallery of all the POT example files. /auto_examples/plot_otda_mapping_colors_images +.. raw:: html + +
+ +.. only:: html + + .. figure:: /auto_examples/images/thumb/sphx_glr_plot_UOT_barycenter_1D_thumb.png + + :ref:`sphx_glr_auto_examples_plot_UOT_barycenter_1D.py` + +.. raw:: html + +
+ + +.. toctree:: + :hidden: + + /auto_examples/plot_UOT_barycenter_1D + .. raw:: html
diff --git a/docs/source/auto_examples/plot_UOT_1D.ipynb b/docs/source/auto_examples/plot_UOT_1D.ipynb new file mode 100644 index 0000000..c695306 --- /dev/null +++ b/docs/source/auto_examples/plot_UOT_1D.ipynb @@ -0,0 +1,108 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# 1D Unbalanced optimal transport\n\n\nThis example illustrates the computation of Unbalanced Optimal transport\nusing a Kullback-Leibler relaxation.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Author: Hicham Janati \n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\nimport ot.plot\nfrom ot.datasets import make_1D_gauss as gauss" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Generate data\n-------------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "#%% parameters\n\nn = 100 # nb bins\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\n# Gaussian distributions\na = gauss(n, m=20, s=5) # m= mean, s= std\nb = gauss(n, m=60, s=10)\n\n# make distributions unbalanced\nb *= 5.\n\n# loss matrix\nM = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))\nM /= M.max()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot distributions and loss matrix\n----------------------------------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "#%% plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\npl.plot(x, a, 'b', label='Source distribution')\npl.plot(x, b, 'r', label='Target distribution')\npl.legend()\n\n# plot distributions and loss matrix\n\npl.figure(2, figsize=(5, 5))\not.plot.plot1D_mat(a, b, M, 'Cost matrix M')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Solve Unbalanced Sinkhorn\n--------------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Sinkhorn\n\nepsilon = 0.1 # entropy parameter\nalpha = 1. # Unbalanced KL relaxation parameter\nGs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True)\n\npl.figure(4, figsize=(5, 5))\not.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn')\n\npl.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/docs/source/auto_examples/plot_UOT_1D.py b/docs/source/auto_examples/plot_UOT_1D.py new file mode 100644 index 0000000..2ea8b05 --- /dev/null +++ b/docs/source/auto_examples/plot_UOT_1D.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +""" +=============================== +1D Unbalanced optimal transport +=============================== + +This example illustrates the computation of Unbalanced Optimal transport +using a Kullback-Leibler relaxation. +""" + +# Author: Hicham Janati +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot +import ot.plot +from ot.datasets import make_1D_gauss as gauss + +############################################################################## +# Generate data +# ------------- + + +#%% parameters + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a = gauss(n, m=20, s=5) # m= mean, s= std +b = gauss(n, m=60, s=10) + +# make distributions unbalanced +b *= 5. + +# loss matrix +M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1))) +M /= M.max() + + +############################################################################## +# Plot distributions and loss matrix +# ---------------------------------- + +#%% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') +pl.legend() + +# plot distributions and loss matrix + +pl.figure(2, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') + + +############################################################################## +# Solve Unbalanced Sinkhorn +# -------------- + + +# Sinkhorn + +epsilon = 0.1 # entropy parameter +alpha = 1. # Unbalanced KL relaxation parameter +Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True) + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn') + +pl.show() diff --git a/docs/source/auto_examples/plot_UOT_1D.rst b/docs/source/auto_examples/plot_UOT_1D.rst new file mode 100644 index 0000000..8e618b4 --- /dev/null +++ b/docs/source/auto_examples/plot_UOT_1D.rst @@ -0,0 +1,173 @@ + + +.. _sphx_glr_auto_examples_plot_UOT_1D.py: + + +=============================== +1D Unbalanced optimal transport +=============================== + +This example illustrates the computation of Unbalanced Optimal transport +using a Kullback-Leibler relaxation. + + + +.. code-block:: python + + + # Author: Hicham Janati + # + # License: MIT License + + import numpy as np + import matplotlib.pylab as pl + import ot + import ot.plot + from ot.datasets import make_1D_gauss as gauss + + + + + + + +Generate data +------------- + + + +.. code-block:: python + + + + #%% parameters + + n = 100 # nb bins + + # bin positions + x = np.arange(n, dtype=np.float64) + + # Gaussian distributions + a = gauss(n, m=20, s=5) # m= mean, s= std + b = gauss(n, m=60, s=10) + + # make distributions unbalanced + b *= 5. + + # loss matrix + M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1))) + M /= M.max() + + + + + + + + +Plot distributions and loss matrix +---------------------------------- + + + +.. code-block:: python + + + #%% plot the distributions + + pl.figure(1, figsize=(6.4, 3)) + pl.plot(x, a, 'b', label='Source distribution') + pl.plot(x, b, 'r', label='Target distribution') + pl.legend() + + # plot distributions and loss matrix + + pl.figure(2, figsize=(5, 5)) + ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') + + + + + +.. rst-class:: sphx-glr-horizontal + + + * + + .. image:: /auto_examples/images/sphx_glr_plot_UOT_1D_001.png + :scale: 47 + + * + + .. image:: /auto_examples/images/sphx_glr_plot_UOT_1D_002.png + :scale: 47 + + + + +Solve Unbalanced Sinkhorn +-------------- + + + +.. code-block:: python + + + + # Sinkhorn + + epsilon = 0.1 # entropy parameter + alpha = 1. # Unbalanced KL relaxation parameter + Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True) + + pl.figure(4, figsize=(5, 5)) + ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn') + + pl.show() + + + +.. image:: /auto_examples/images/sphx_glr_plot_UOT_1D_006.png + :align: center + + +.. rst-class:: sphx-glr-script-out + + Out:: + + It. |Err + ------------------- + 0|1.838786e+00| + 10|1.242379e-01| + 20|2.581314e-03| + 30|5.674552e-05| + 40|1.252959e-06| + 50|2.768136e-08| + 60|6.116090e-10| + + +**Total running time of the script:** ( 0 minutes 0.259 seconds) + + + +.. only :: html + + .. container:: sphx-glr-footer + + + .. container:: sphx-glr-download + + :download:`Download Python source code: plot_UOT_1D.py ` + + + + .. container:: sphx-glr-download + + :download:`Download Jupyter notebook: plot_UOT_1D.ipynb ` + + +.. only:: html + + .. rst-class:: sphx-glr-signature + + `Gallery generated by Sphinx-Gallery `_ diff --git a/docs/source/auto_examples/plot_UOT_barycenter_1D.ipynb b/docs/source/auto_examples/plot_UOT_barycenter_1D.ipynb new file mode 100644 index 0000000..e59cdc2 --- /dev/null +++ b/docs/source/auto_examples/plot_UOT_barycenter_1D.ipynb @@ -0,0 +1,126 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# 1D Wasserstein barycenter demo for Unbalanced distributions\n\n\nThis example illustrates the computation of regularized Wassersyein Barycenter\nas proposed in [10] for Unbalanced inputs.\n\n\n[10] Chizat, L., Peyr\u00e9, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.\n\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Author: Hicham Janati \n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\n# necessary for 3d plot even if not used\nfrom mpl_toolkits.mplot3d import Axes3D # noqa\nfrom matplotlib.collections import PolyCollection" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Generate data\n-------------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# parameters\n\nn = 100 # nb bins\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\n# Gaussian distributions\na1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std\na2 = ot.datasets.make_1D_gauss(n, m=60, s=8)\n\n# make unbalanced dists\na2 *= 3.\n\n# creating matrix A containing all distributions\nA = np.vstack((a1, a2)).T\nn_distributions = A.shape[1]\n\n# loss matrix + normalization\nM = ot.utils.dist0(n)\nM /= M.max()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot data\n---------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\npl.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Barycenter computation\n----------------------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# non weighted barycenter computation\n\nweight = 0.5 # 0<=weight<=1\nweights = np.array([1 - weight, weight])\n\n# l2bary\nbary_l2 = A.dot(weights)\n\n# wasserstein\nreg = 1e-3\nalpha = 1.\n\nbary_wass = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)\n\npl.figure(2)\npl.clf()\npl.subplot(2, 1, 1)\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\n\npl.subplot(2, 1, 2)\npl.plot(x, bary_l2, 'r', label='l2')\npl.plot(x, bary_wass, 'g', label='Wasserstein')\npl.legend()\npl.title('Barycenters')\npl.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Barycentric interpolation\n-------------------------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# barycenter interpolation\n\nn_weight = 11\nweight_list = np.linspace(0, 1, n_weight)\n\n\nB_l2 = np.zeros((n, n_weight))\n\nB_wass = np.copy(B_l2)\n\nfor i in range(0, n_weight):\n weight = weight_list[i]\n weights = np.array([1 - weight, weight])\n B_l2[:, i] = A.dot(weights)\n B_wass[:, i] = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)\n\n\n# plot interpolation\n\npl.figure(3)\n\ncmap = pl.cm.get_cmap('viridis')\nverts = []\nzs = weight_list\nfor i, z in enumerate(zs):\n ys = B_l2[:, i]\n verts.append(list(zip(x, ys)))\n\nax = pl.gcf().gca(projection='3d')\n\npoly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])\npoly.set_alpha(0.7)\nax.add_collection3d(poly, zs=zs, zdir='y')\nax.set_xlabel('x')\nax.set_xlim3d(0, n)\nax.set_ylabel(r'$\\alpha$')\nax.set_ylim3d(0, 1)\nax.set_zlabel('')\nax.set_zlim3d(0, B_l2.max() * 1.01)\npl.title('Barycenter interpolation with l2')\npl.tight_layout()\n\npl.figure(4)\ncmap = pl.cm.get_cmap('viridis')\nverts = []\nzs = weight_list\nfor i, z in enumerate(zs):\n ys = B_wass[:, i]\n verts.append(list(zip(x, ys)))\n\nax = pl.gcf().gca(projection='3d')\n\npoly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])\npoly.set_alpha(0.7)\nax.add_collection3d(poly, zs=zs, zdir='y')\nax.set_xlabel('x')\nax.set_xlim3d(0, n)\nax.set_ylabel(r'$\\alpha$')\nax.set_ylim3d(0, 1)\nax.set_zlabel('')\nax.set_zlim3d(0, B_l2.max() * 1.01)\npl.title('Barycenter interpolation with Wasserstein')\npl.tight_layout()\n\npl.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/docs/source/auto_examples/plot_UOT_barycenter_1D.py b/docs/source/auto_examples/plot_UOT_barycenter_1D.py new file mode 100644 index 0000000..c8d9d3b --- /dev/null +++ b/docs/source/auto_examples/plot_UOT_barycenter_1D.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +""" +=========================================================== +1D Wasserstein barycenter demo for Unbalanced distributions +=========================================================== + +This example illustrates the computation of regularized Wassersyein Barycenter +as proposed in [10] for Unbalanced inputs. + + +[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + +""" + +# Author: Hicham Janati +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot +# necessary for 3d plot even if not used +from mpl_toolkits.mplot3d import Axes3D # noqa +from matplotlib.collections import PolyCollection + +############################################################################## +# Generate data +# ------------- + +# parameters + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std +a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) + +# make unbalanced dists +a2 *= 3. + +# creating matrix A containing all distributions +A = np.vstack((a1, a2)).T +n_distributions = A.shape[1] + +# loss matrix + normalization +M = ot.utils.dist0(n) +M /= M.max() + +############################################################################## +# Plot data +# --------- + +# plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') +pl.tight_layout() + +############################################################################## +# Barycenter computation +# ---------------------- + +# non weighted barycenter computation + +weight = 0.5 # 0<=weight<=1 +weights = np.array([1 - weight, weight]) + +# l2bary +bary_l2 = A.dot(weights) + +# wasserstein +reg = 1e-3 +alpha = 1. + +bary_wass = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights) + +pl.figure(2) +pl.clf() +pl.subplot(2, 1, 1) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') + +pl.subplot(2, 1, 2) +pl.plot(x, bary_l2, 'r', label='l2') +pl.plot(x, bary_wass, 'g', label='Wasserstein') +pl.legend() +pl.title('Barycenters') +pl.tight_layout() + +############################################################################## +# Barycentric interpolation +# ------------------------- + +# barycenter interpolation + +n_weight = 11 +weight_list = np.linspace(0, 1, n_weight) + + +B_l2 = np.zeros((n, n_weight)) + +B_wass = np.copy(B_l2) + +for i in range(0, n_weight): + weight = weight_list[i] + weights = np.array([1 - weight, weight]) + B_l2[:, i] = A.dot(weights) + B_wass[:, i] = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights) + + +# plot interpolation + +pl.figure(3) + +cmap = pl.cm.get_cmap('viridis') +verts = [] +zs = weight_list +for i, z in enumerate(zs): + ys = B_l2[:, i] + verts.append(list(zip(x, ys))) + +ax = pl.gcf().gca(projection='3d') + +poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list]) +poly.set_alpha(0.7) +ax.add_collection3d(poly, zs=zs, zdir='y') +ax.set_xlabel('x') +ax.set_xlim3d(0, n) +ax.set_ylabel(r'$\alpha$') +ax.set_ylim3d(0, 1) +ax.set_zlabel('') +ax.set_zlim3d(0, B_l2.max() * 1.01) +pl.title('Barycenter interpolation with l2') +pl.tight_layout() + +pl.figure(4) +cmap = pl.cm.get_cmap('viridis') +verts = [] +zs = weight_list +for i, z in enumerate(zs): + ys = B_wass[:, i] + verts.append(list(zip(x, ys))) + +ax = pl.gcf().gca(projection='3d') + +poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list]) +poly.set_alpha(0.7) +ax.add_collection3d(poly, zs=zs, zdir='y') +ax.set_xlabel('x') +ax.set_xlim3d(0, n) +ax.set_ylabel(r'$\alpha$') +ax.set_ylim3d(0, 1) +ax.set_zlabel('') +ax.set_zlim3d(0, B_l2.max() * 1.01) +pl.title('Barycenter interpolation with Wasserstein') +pl.tight_layout() + +pl.show() diff --git a/docs/source/auto_examples/plot_UOT_barycenter_1D.rst b/docs/source/auto_examples/plot_UOT_barycenter_1D.rst new file mode 100644 index 0000000..ac17587 --- /dev/null +++ b/docs/source/auto_examples/plot_UOT_barycenter_1D.rst @@ -0,0 +1,261 @@ + + +.. _sphx_glr_auto_examples_plot_UOT_barycenter_1D.py: + + +=========================================================== +1D Wasserstein barycenter demo for Unbalanced distributions +=========================================================== + +This example illustrates the computation of regularized Wassersyein Barycenter +as proposed in [10] for Unbalanced inputs. + + +[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + + + + +.. code-block:: python + + + # Author: Hicham Janati + # + # License: MIT License + + import numpy as np + import matplotlib.pylab as pl + import ot + # necessary for 3d plot even if not used + from mpl_toolkits.mplot3d import Axes3D # noqa + from matplotlib.collections import PolyCollection + + + + + + + +Generate data +------------- + + + +.. code-block:: python + + + # parameters + + n = 100 # nb bins + + # bin positions + x = np.arange(n, dtype=np.float64) + + # Gaussian distributions + a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std + a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) + + # make unbalanced dists + a2 *= 3. + + # creating matrix A containing all distributions + A = np.vstack((a1, a2)).T + n_distributions = A.shape[1] + + # loss matrix + normalization + M = ot.utils.dist0(n) + M /= M.max() + + + + + + + +Plot data +--------- + + + +.. code-block:: python + + + # plot the distributions + + pl.figure(1, figsize=(6.4, 3)) + for i in range(n_distributions): + pl.plot(x, A[:, i]) + pl.title('Distributions') + pl.tight_layout() + + + + +.. image:: /auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_001.png + :align: center + + + + +Barycenter computation +---------------------- + + + +.. code-block:: python + + + # non weighted barycenter computation + + weight = 0.5 # 0<=weight<=1 + weights = np.array([1 - weight, weight]) + + # l2bary + bary_l2 = A.dot(weights) + + # wasserstein + reg = 1e-3 + alpha = 1. + + bary_wass = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights) + + pl.figure(2) + pl.clf() + pl.subplot(2, 1, 1) + for i in range(n_distributions): + pl.plot(x, A[:, i]) + pl.title('Distributions') + + pl.subplot(2, 1, 2) + pl.plot(x, bary_l2, 'r', label='l2') + pl.plot(x, bary_wass, 'g', label='Wasserstein') + pl.legend() + pl.title('Barycenters') + pl.tight_layout() + + + + +.. image:: /auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_003.png + :align: center + + + + +Barycentric interpolation +------------------------- + + + +.. code-block:: python + + + # barycenter interpolation + + n_weight = 11 + weight_list = np.linspace(0, 1, n_weight) + + + B_l2 = np.zeros((n, n_weight)) + + B_wass = np.copy(B_l2) + + for i in range(0, n_weight): + weight = weight_list[i] + weights = np.array([1 - weight, weight]) + B_l2[:, i] = A.dot(weights) + B_wass[:, i] = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights) + + + # plot interpolation + + pl.figure(3) + + cmap = pl.cm.get_cmap('viridis') + verts = [] + zs = weight_list + for i, z in enumerate(zs): + ys = B_l2[:, i] + verts.append(list(zip(x, ys))) + + ax = pl.gcf().gca(projection='3d') + + poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list]) + poly.set_alpha(0.7) + ax.add_collection3d(poly, zs=zs, zdir='y') + ax.set_xlabel('x') + ax.set_xlim3d(0, n) + ax.set_ylabel(r'$\alpha$') + ax.set_ylim3d(0, 1) + ax.set_zlabel('') + ax.set_zlim3d(0, B_l2.max() * 1.01) + pl.title('Barycenter interpolation with l2') + pl.tight_layout() + + pl.figure(4) + cmap = pl.cm.get_cmap('viridis') + verts = [] + zs = weight_list + for i, z in enumerate(zs): + ys = B_wass[:, i] + verts.append(list(zip(x, ys))) + + ax = pl.gcf().gca(projection='3d') + + poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list]) + poly.set_alpha(0.7) + ax.add_collection3d(poly, zs=zs, zdir='y') + ax.set_xlabel('x') + ax.set_xlim3d(0, n) + ax.set_ylabel(r'$\alpha$') + ax.set_ylim3d(0, 1) + ax.set_zlabel('') + ax.set_zlim3d(0, B_l2.max() * 1.01) + pl.title('Barycenter interpolation with Wasserstein') + pl.tight_layout() + + pl.show() + + + +.. rst-class:: sphx-glr-horizontal + + + * + + .. image:: /auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_005.png + :scale: 47 + + * + + .. image:: /auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_006.png + :scale: 47 + + + + +**Total running time of the script:** ( 0 minutes 0.344 seconds) + + + +.. only :: html + + .. container:: sphx-glr-footer + + + .. container:: sphx-glr-download + + :download:`Download Python source code: plot_UOT_barycenter_1D.py ` + + + + .. container:: sphx-glr-download + + :download:`Download Jupyter notebook: plot_UOT_barycenter_1D.ipynb ` + + +.. only:: html + + .. rst-class:: sphx-glr-signature + + `Gallery generated by Sphinx-Gallery `_ -- cgit v1.2.3