diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2017-09-15 14:54:21 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-09-15 14:54:21 +0200 |
commit | 81b2796226f3abde29fc024752728444da77509a (patch) | |
tree | c52cec3c38552f9f8c15361758aa9a80c30c3ef3 /docs/source/auto_examples/plot_OT_1D.ipynb | |
parent | e70d5420204db78691af2d0fbe04cc3d4416a8f4 (diff) | |
parent | 7fea2cd3e8ad29bf3fa442d7642bae124ee2bab0 (diff) |
Merge pull request #27 from rflamary/autonb
auto notebooks + release update (fixes #16)
Diffstat (limited to 'docs/source/auto_examples/plot_OT_1D.ipynb')
-rw-r--r-- | docs/source/auto_examples/plot_OT_1D.ipynb | 76 |
1 files changed, 74 insertions, 2 deletions
diff --git a/docs/source/auto_examples/plot_OT_1D.ipynb b/docs/source/auto_examples/plot_OT_1D.ipynb index 8715b97..26748c2 100644 --- a/docs/source/auto_examples/plot_OT_1D.ipynb +++ b/docs/source/auto_examples/plot_OT_1D.ipynb @@ -15,7 +15,7 @@ }, { "source": [ - "\n# 1D optimal transport\n\n\n@author: rflamary\n\n" + "\n# 1D optimal transport\n\n\nThis example illustrates the computation of EMD and Sinkhorn transport plans\nand their visualization.\n\n\n" ], "cell_type": "markdown", "metadata": {} @@ -24,7 +24,79 @@ "execution_count": null, "cell_type": "code", "source": [ - "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\nfrom ot.datasets import get_1D_gauss as gauss\n\n\n#%% parameters\n\nn=100 # nb bins\n\n# bin positions\nx=np.arange(n,dtype=np.float64)\n\n# Gaussian distributions\na=gauss(n,m=20,s=5) # m= mean, s= std\nb=gauss(n,m=60,s=10)\n\n# loss matrix\nM=ot.dist(x.reshape((n,1)),x.reshape((n,1)))\nM/=M.max()\n\n#%% plot the distributions\n\npl.figure(1)\npl.plot(x,a,'b',label='Source distribution')\npl.plot(x,b,'r',label='Target distribution')\npl.legend()\n\n#%% plot distributions and loss matrix\n\npl.figure(2)\not.plot.plot1D_mat(a,b,M,'Cost matrix M')\n\n#%% EMD\n\nG0=ot.emd(a,b,M)\n\npl.figure(3)\not.plot.plot1D_mat(a,b,G0,'OT matrix G0')\n\n#%% Sinkhorn\n\nlambd=1e-3\nGs=ot.sinkhorn(a,b,M,lambd,verbose=True)\n\npl.figure(4)\not.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn')" + "# Author: Remi Flamary <remi.flamary@unice.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\nfrom ot.datasets import get_1D_gauss as gauss" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + }, + { + "source": [ + "Generate data\n-------------\n\n" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "execution_count": null, + "cell_type": "code", + "source": [ + "#%% parameters\n\nn = 100 # nb bins\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\n# Gaussian distributions\na = gauss(n, m=20, s=5) # m= mean, s= std\nb = gauss(n, m=60, s=10)\n\n# loss matrix\nM = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))\nM /= M.max()" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + }, + { + "source": [ + "Plot distributions and loss matrix\n----------------------------------\n\n" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "execution_count": null, + "cell_type": "code", + "source": [ + "#%% plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\npl.plot(x, a, 'b', label='Source distribution')\npl.plot(x, b, 'r', label='Target distribution')\npl.legend()\n\n#%% plot distributions and loss matrix\n\npl.figure(2, figsize=(5, 5))\not.plot.plot1D_mat(a, b, M, 'Cost matrix M')" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + }, + { + "source": [ + "Solve EMD\n---------\n\n" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "execution_count": null, + "cell_type": "code", + "source": [ + "#%% EMD\n\nG0 = ot.emd(a, b, M)\n\npl.figure(3, figsize=(5, 5))\not.plot.plot1D_mat(a, b, G0, 'OT matrix G0')" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + }, + { + "source": [ + "Solve Sinkhorn\n--------------\n\n" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "execution_count": null, + "cell_type": "code", + "source": [ + "#%% Sinkhorn\n\nlambd = 1e-3\nGs = ot.sinkhorn(a, b, M, lambd, verbose=True)\n\npl.figure(4, figsize=(5, 5))\not.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn')\n\npl.show()" ], "outputs": [], "metadata": { |