diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-12-02 15:38:59 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-12-02 15:38:59 +0100 |
commit | e458b7a58d9790e7c5ff40dea235402d9c4c8662 (patch) | |
tree | ac9da575654c78aa04a177723603935051b5d42d /docs/source | |
parent | 7609f9e6a4103e13beb294873f4dac562b1d45e1 (diff) |
add doc for gallery
Diffstat (limited to 'docs/source')
77 files changed, 4163 insertions, 1 deletions
diff --git a/docs/source/auto_examples/auto_examples_jupyter.zip b/docs/source/auto_examples/auto_examples_jupyter.zip Binary files differnew file mode 100644 index 0000000..4840aa7 --- /dev/null +++ b/docs/source/auto_examples/auto_examples_jupyter.zip diff --git a/docs/source/auto_examples/auto_examples_python.zip b/docs/source/auto_examples/auto_examples_python.zip Binary files differnew file mode 100644 index 0000000..6e73f21 --- /dev/null +++ b/docs/source/auto_examples/auto_examples_python.zip diff --git a/docs/source/auto_examples/demo_OT_1D_test.ipynb b/docs/source/auto_examples/demo_OT_1D_test.ipynb new file mode 100644 index 0000000..87317ea --- /dev/null +++ b/docs/source/auto_examples/demo_OT_1D_test.ipynb @@ -0,0 +1,54 @@ +{ + "nbformat_minor": 0, + "nbformat": 4, + "cells": [ + { + "execution_count": null, + "cell_type": "code", + "source": [ + "%matplotlib inline" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + }, + { + "source": [ + "\nDemo for 1D optimal transport\n\n@author: rflamary\n\n" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "execution_count": null, + "cell_type": "code", + "source": [ + "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\nfrom ot.datasets import get_1D_gauss as gauss\n\n\n#%% parameters\n\nn=100 # nb bins\n\n# bin positions\nx=np.arange(n,dtype=np.float64)\n\n# Gaussian distributions\na=gauss(n,m=n*.2,s=5) # m= mean, s= std\nb=gauss(n,m=n*.6,s=10)\n\n# loss matrix\nM=ot.dist(x.reshape((n,1)),x.reshape((n,1)))\nM/=M.max()\n\n#%% plot the distributions\n\npl.figure(1)\npl.plot(x,a,'b',label='Source distribution')\npl.plot(x,b,'r',label='Target distribution')\npl.legend()\n\n#%% plot distributions and loss matrix\n\npl.figure(2)\not.plot.plot1D_mat(a,b,M,'Cost matrix M')\n\n#%% EMD\n\nG0=ot.emd(a,b,M)\n\npl.figure(3)\not.plot.plot1D_mat(a,b,G0,'OT matrix G0')\n\n#%% Sinkhorn\n\nlambd=1e-3\nGs=ot.sinkhorn(a,b,M,lambd,verbose=True)\n\npl.figure(4)\not.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn')\n\n#%% Sinkhorn\n\nlambd=1e-4\nGss,log=ot.bregman.sinkhorn_stabilized(a,b,M,lambd,verbose=True,log=True)\nGss2,log2=ot.bregman.sinkhorn_stabilized(a,b,M,lambd,verbose=True,log=True,warmstart=log['warmstart'])\n\npl.figure(5)\not.plot.plot1D_mat(a,b,Gss,'OT matrix Sinkhorn stabilized')\n\n#%% Sinkhorn\n\nlambd=1e-11\nGss=ot.bregman.sinkhorn_epsilon_scaling(a,b,M,lambd,verbose=True)\n\npl.figure(5)\not.plot.plot1D_mat(a,b,Gss,'OT matrix Sinkhorn stabilized')" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "name": "python2", + "language": "python" + }, + "language_info": { + "mimetype": "text/x-python", + "nbconvert_exporter": "python", + "name": "python", + "file_extension": ".py", + "version": "2.7.12", + "pygments_lexer": "ipython2", + "codemirror_mode": { + "version": 2, + "name": "ipython" + } + } + } +}
\ No newline at end of file diff --git a/docs/source/auto_examples/demo_OT_1D_test.py b/docs/source/auto_examples/demo_OT_1D_test.py new file mode 100644 index 0000000..9edc377 --- /dev/null +++ b/docs/source/auto_examples/demo_OT_1D_test.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +""" +Demo for 1D optimal transport + +@author: rflamary +""" + +import numpy as np +import matplotlib.pylab as pl +import ot +from ot.datasets import get_1D_gauss as gauss + + +#%% parameters + +n=100 # nb bins + +# bin positions +x=np.arange(n,dtype=np.float64) + +# Gaussian distributions +a=gauss(n,m=n*.2,s=5) # m= mean, s= std +b=gauss(n,m=n*.6,s=10) + +# loss matrix +M=ot.dist(x.reshape((n,1)),x.reshape((n,1))) +M/=M.max() + +#%% plot the distributions + +pl.figure(1) +pl.plot(x,a,'b',label='Source distribution') +pl.plot(x,b,'r',label='Target distribution') +pl.legend() + +#%% plot distributions and loss matrix + +pl.figure(2) +ot.plot.plot1D_mat(a,b,M,'Cost matrix M') + +#%% EMD + +G0=ot.emd(a,b,M) + +pl.figure(3) +ot.plot.plot1D_mat(a,b,G0,'OT matrix G0') + +#%% Sinkhorn + +lambd=1e-3 +Gs=ot.sinkhorn(a,b,M,lambd,verbose=True) + +pl.figure(4) +ot.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn') + +#%% Sinkhorn + +lambd=1e-4 +Gss,log=ot.bregman.sinkhorn_stabilized(a,b,M,lambd,verbose=True,log=True) +Gss2,log2=ot.bregman.sinkhorn_stabilized(a,b,M,lambd,verbose=True,log=True,warmstart=log['warmstart']) + +pl.figure(5) +ot.plot.plot1D_mat(a,b,Gss,'OT matrix Sinkhorn stabilized') + +#%% Sinkhorn + +lambd=1e-11 +Gss=ot.bregman.sinkhorn_epsilon_scaling(a,b,M,lambd,verbose=True) + +pl.figure(5) +ot.plot.plot1D_mat(a,b,Gss,'OT matrix Sinkhorn stabilized') diff --git a/docs/source/auto_examples/demo_OT_1D_test.rst b/docs/source/auto_examples/demo_OT_1D_test.rst new file mode 100644 index 0000000..aebeb1d --- /dev/null +++ b/docs/source/auto_examples/demo_OT_1D_test.rst @@ -0,0 +1,99 @@ + + +.. _sphx_glr_auto_examples_demo_OT_1D_test.py: + + +Demo for 1D optimal transport + +@author: rflamary + + + +.. code-block:: python + + + import numpy as np + import matplotlib.pylab as pl + import ot + from ot.datasets import get_1D_gauss as gauss + + + #%% parameters + + n=100 # nb bins + + # bin positions + x=np.arange(n,dtype=np.float64) + + # Gaussian distributions + a=gauss(n,m=n*.2,s=5) # m= mean, s= std + b=gauss(n,m=n*.6,s=10) + + # loss matrix + M=ot.dist(x.reshape((n,1)),x.reshape((n,1))) + M/=M.max() + + #%% plot the distributions + + pl.figure(1) + pl.plot(x,a,'b',label='Source distribution') + pl.plot(x,b,'r',label='Target distribution') + pl.legend() + + #%% plot distributions and loss matrix + + pl.figure(2) + ot.plot.plot1D_mat(a,b,M,'Cost matrix M') + + #%% EMD + + G0=ot.emd(a,b,M) + + pl.figure(3) + ot.plot.plot1D_mat(a,b,G0,'OT matrix G0') + + #%% Sinkhorn + + lambd=1e-3 + Gs=ot.sinkhorn(a,b,M,lambd,verbose=True) + + pl.figure(4) + ot.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn') + + #%% Sinkhorn + + lambd=1e-4 + Gss,log=ot.bregman.sinkhorn_stabilized(a,b,M,lambd,verbose=True,log=True) + Gss2,log2=ot.bregman.sinkhorn_stabilized(a,b,M,lambd,verbose=True,log=True,warmstart=log['warmstart']) + + pl.figure(5) + ot.plot.plot1D_mat(a,b,Gss,'OT matrix Sinkhorn stabilized') + + #%% Sinkhorn + + lambd=1e-11 + Gss=ot.bregman.sinkhorn_epsilon_scaling(a,b,M,lambd,verbose=True) + + pl.figure(5) + ot.plot.plot1D_mat(a,b,Gss,'OT matrix Sinkhorn stabilized') + +**Total running time of the script:** ( 0 minutes 0.000 seconds) + + + +.. container:: sphx-glr-footer + + + .. container:: sphx-glr-download + + :download:`Download Python source code: demo_OT_1D_test.py <demo_OT_1D_test.py>` + + + + .. container:: sphx-glr-download + + :download:`Download Jupyter notebook: demo_OT_1D_test.ipynb <demo_OT_1D_test.ipynb>` + +.. rst-class:: sphx-glr-signature + + `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_ diff --git a/docs/source/auto_examples/demo_OT_2D_sampleslarge.ipynb b/docs/source/auto_examples/demo_OT_2D_sampleslarge.ipynb new file mode 100644 index 0000000..584a936 --- /dev/null +++ b/docs/source/auto_examples/demo_OT_2D_sampleslarge.ipynb @@ -0,0 +1,54 @@ +{ + "nbformat_minor": 0, + "nbformat": 4, + "cells": [ + { + "execution_count": null, + "cell_type": "code", + "source": [ + "%matplotlib inline" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + }, + { + "source": [ + "\nDemo for 2D Optimal transport between empirical distributions\n\n@author: rflamary\n\n" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "execution_count": null, + "cell_type": "code", + "source": [ + "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\n\n#%% parameters and data generation\n\nn=5000 # nb samples\n\nmu_s=np.array([0,0])\ncov_s=np.array([[1,0],[0,1]])\n\nmu_t=np.array([4,4])\ncov_t=np.array([[1,-.8],[-.8,1]])\n\nxs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s)\nxt=ot.datasets.get_2D_samples_gauss(n,mu_t,cov_t)\n\na,b = ot.unif(n),ot.unif(n) # uniform distribution on samples\n\n# loss matrix\nM=ot.dist(xs,xt)\nM/=M.max()\n\n#%% plot samples\n\n#pl.figure(1)\n#pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\n#pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\n#pl.legend(loc=0)\n#pl.title('Source and traget distributions')\n#\n#pl.figure(2)\n#pl.imshow(M,interpolation='nearest')\n#pl.title('Cost matrix M')\n#\n\n#%% EMD\n\nG0=ot.emd(a,b,M)\n\n#pl.figure(3)\n#pl.imshow(G0,interpolation='nearest')\n#pl.title('OT matrix G0')\n#\n#pl.figure(4)\n#ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])\n#pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\n#pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\n#pl.legend(loc=0)\n#pl.title('OT matrix with samples')\n\n\n#%% sinkhorn\n\n# reg term\nlambd=5e-3\n\nGs=ot.sinkhorn(a,b,M,lambd)\n\n#pl.figure(5)\n#pl.imshow(Gs,interpolation='nearest')\n#pl.title('OT matrix sinkhorn')\n#\n#pl.figure(6)\n#ot.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1])\n#pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\n#pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\n#pl.legend(loc=0)\n#pl.title('OT matrix Sinkhorn with samples')\n#" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "name": "python2", + "language": "python" + }, + "language_info": { + "mimetype": "text/x-python", + "nbconvert_exporter": "python", + "name": "python", + "file_extension": ".py", + "version": "2.7.12", + "pygments_lexer": "ipython2", + "codemirror_mode": { + "version": 2, + "name": "ipython" + } + } + } +}
\ No newline at end of file diff --git a/docs/source/auto_examples/demo_OT_2D_sampleslarge.py b/docs/source/auto_examples/demo_OT_2D_sampleslarge.py new file mode 100644 index 0000000..ee3e8f7 --- /dev/null +++ b/docs/source/auto_examples/demo_OT_2D_sampleslarge.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- +""" +Demo for 2D Optimal transport between empirical distributions + +@author: rflamary +""" + +import numpy as np +import matplotlib.pylab as pl +import ot + +#%% parameters and data generation + +n=5000 # nb samples + +mu_s=np.array([0,0]) +cov_s=np.array([[1,0],[0,1]]) + +mu_t=np.array([4,4]) +cov_t=np.array([[1,-.8],[-.8,1]]) + +xs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s) +xt=ot.datasets.get_2D_samples_gauss(n,mu_t,cov_t) + +a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples + +# loss matrix +M=ot.dist(xs,xt) +M/=M.max() + +#%% plot samples + +#pl.figure(1) +#pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') +#pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') +#pl.legend(loc=0) +#pl.title('Source and traget distributions') +# +#pl.figure(2) +#pl.imshow(M,interpolation='nearest') +#pl.title('Cost matrix M') +# + +#%% EMD + +G0=ot.emd(a,b,M) + +#pl.figure(3) +#pl.imshow(G0,interpolation='nearest') +#pl.title('OT matrix G0') +# +#pl.figure(4) +#ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1]) +#pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') +#pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') +#pl.legend(loc=0) +#pl.title('OT matrix with samples') + + +#%% sinkhorn + +# reg term +lambd=5e-3 + +Gs=ot.sinkhorn(a,b,M,lambd) + +#pl.figure(5) +#pl.imshow(Gs,interpolation='nearest') +#pl.title('OT matrix sinkhorn') +# +#pl.figure(6) +#ot.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1]) +#pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') +#pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') +#pl.legend(loc=0) +#pl.title('OT matrix Sinkhorn with samples') +# + diff --git a/docs/source/auto_examples/demo_OT_2D_sampleslarge.rst b/docs/source/auto_examples/demo_OT_2D_sampleslarge.rst new file mode 100644 index 0000000..f5dbb0d --- /dev/null +++ b/docs/source/auto_examples/demo_OT_2D_sampleslarge.rst @@ -0,0 +1,106 @@ + + +.. _sphx_glr_auto_examples_demo_OT_2D_sampleslarge.py: + + +Demo for 2D Optimal transport between empirical distributions + +@author: rflamary + + + +.. code-block:: python + + + import numpy as np + import matplotlib.pylab as pl + import ot + + #%% parameters and data generation + + n=5000 # nb samples + + mu_s=np.array([0,0]) + cov_s=np.array([[1,0],[0,1]]) + + mu_t=np.array([4,4]) + cov_t=np.array([[1,-.8],[-.8,1]]) + + xs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s) + xt=ot.datasets.get_2D_samples_gauss(n,mu_t,cov_t) + + a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples + + # loss matrix + M=ot.dist(xs,xt) + M/=M.max() + + #%% plot samples + + #pl.figure(1) + #pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') + #pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') + #pl.legend(loc=0) + #pl.title('Source and traget distributions') + # + #pl.figure(2) + #pl.imshow(M,interpolation='nearest') + #pl.title('Cost matrix M') + # + + #%% EMD + + G0=ot.emd(a,b,M) + + #pl.figure(3) + #pl.imshow(G0,interpolation='nearest') + #pl.title('OT matrix G0') + # + #pl.figure(4) + #ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1]) + #pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') + #pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') + #pl.legend(loc=0) + #pl.title('OT matrix with samples') + + + #%% sinkhorn + + # reg term + lambd=5e-3 + + Gs=ot.sinkhorn(a,b,M,lambd) + + #pl.figure(5) + #pl.imshow(Gs,interpolation='nearest') + #pl.title('OT matrix sinkhorn') + # + #pl.figure(6) + #ot.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1]) + #pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') + #pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') + #pl.legend(loc=0) + #pl.title('OT matrix Sinkhorn with samples') + # + + +**Total running time of the script:** ( 0 minutes 0.000 seconds) + + + +.. container:: sphx-glr-footer + + + .. container:: sphx-glr-download + + :download:`Download Python source code: demo_OT_2D_sampleslarge.py <demo_OT_2D_sampleslarge.py>` + + + + .. container:: sphx-glr-download + + :download:`Download Jupyter notebook: demo_OT_2D_sampleslarge.ipynb <demo_OT_2D_sampleslarge.ipynb>` + +.. rst-class:: sphx-glr-signature + + `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_ diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_001.png b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_001.png Binary files differnew file mode 100644 index 0000000..7de2b45 --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_001.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_002.png b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_002.png Binary files differnew file mode 100644 index 0000000..dc34efd --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_002.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_003.png b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_003.png Binary files differnew file mode 100644 index 0000000..fbd72d5 --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_003.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_004.png b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_004.png Binary files differnew file mode 100644 index 0000000..227812d --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_004.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_classes_001.png b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_classes_001.png Binary files differnew file mode 100644 index 0000000..2bf4015 --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_classes_001.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_classes_004.png b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_classes_004.png Binary files differnew file mode 100644 index 0000000..c1fbf57 --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_classes_004.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_color_images_001.png b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_color_images_001.png Binary files differnew file mode 100644 index 0000000..36bc769 --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_color_images_001.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_color_images_002.png b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_color_images_002.png Binary files differnew file mode 100644 index 0000000..307e384 --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_color_images_002.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_001.png b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_001.png Binary files differnew file mode 100644 index 0000000..8c700ee --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_001.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_002.png b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_002.png Binary files differnew file mode 100644 index 0000000..792b404 --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_002.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_color_images_001.png b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_color_images_001.png Binary files differnew file mode 100644 index 0000000..36bc769 --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_color_images_001.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_color_images_002.png b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_color_images_002.png Binary files differnew file mode 100644 index 0000000..008bf15 --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_color_images_002.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_001.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_001.png Binary files differnew file mode 100644 index 0000000..9b0c3f5 --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_001.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_002.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_002.png Binary files differnew file mode 100644 index 0000000..f2cf4f7 --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_002.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_003.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_003.png Binary files differnew file mode 100644 index 0000000..a4252cf --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_003.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_004.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_004.png Binary files differnew file mode 100644 index 0000000..9bddc80 --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_004.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png Binary files differnew file mode 100644 index 0000000..43b50e8 --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png Binary files differnew file mode 100644 index 0000000..5651ebd --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_003.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_003.png Binary files differnew file mode 100644 index 0000000..b6ca0f4 --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_003.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_004.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_004.png Binary files differnew file mode 100644 index 0000000..ab54a41 --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_004.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png Binary files differnew file mode 100644 index 0000000..24453e1 --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png Binary files differnew file mode 100644 index 0000000..d1b00e7 --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_001.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_001.png Binary files differnew file mode 100644 index 0000000..be71674 --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_001.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_002.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_002.png Binary files differnew file mode 100644 index 0000000..f62240b --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_002.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_003.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_003.png Binary files differnew file mode 100644 index 0000000..11f08b2 --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_003.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_004.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_004.png Binary files differnew file mode 100644 index 0000000..b4e8f71 --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_004.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_003.png b/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_003.png Binary files differnew file mode 100644 index 0000000..e12ebc6 --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_003.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_004.png b/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_004.png Binary files differnew file mode 100644 index 0000000..1d57b8d --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_004.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_005.png b/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_005.png Binary files differnew file mode 100644 index 0000000..4f5f2d4 --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_005.png diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_demo_OT_1D_test_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_demo_OT_1D_test_thumb.png Binary files differnew file mode 100644 index 0000000..cbc8e0f --- /dev/null +++ b/docs/source/auto_examples/images/thumb/sphx_glr_demo_OT_1D_test_thumb.png diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_demo_OT_2D_sampleslarge_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_demo_OT_2D_sampleslarge_thumb.png Binary files differnew file mode 100644 index 0000000..cbc8e0f --- /dev/null +++ b/docs/source/auto_examples/images/thumb/sphx_glr_demo_OT_2D_sampleslarge_thumb.png diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_2D_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_2D_thumb.png Binary files differnew file mode 100644 index 0000000..d15269d --- /dev/null +++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_2D_thumb.png diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_classes_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_classes_thumb.png Binary files differnew file mode 100644 index 0000000..5863d02 --- /dev/null +++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_classes_thumb.png diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_color_images_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_color_images_thumb.png Binary files differnew file mode 100644 index 0000000..5bb43c4 --- /dev/null +++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_color_images_thumb.png diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_mapping_color_images_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_mapping_color_images_thumb.png Binary files differnew file mode 100644 index 0000000..5bb43c4 --- /dev/null +++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_mapping_color_images_thumb.png diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_mapping_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_mapping_thumb.png Binary files differnew file mode 100644 index 0000000..c3d9a65 --- /dev/null +++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_mapping_thumb.png diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_1D_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_1D_thumb.png Binary files differnew file mode 100644 index 0000000..9d0ec13 --- /dev/null +++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_1D_thumb.png diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png Binary files differnew file mode 100644 index 0000000..0ff6768 --- /dev/null +++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_1D_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_1D_thumb.png Binary files differnew file mode 100644 index 0000000..86ff19f --- /dev/null +++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_1D_thumb.png diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_optim_OTreg_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_optim_OTreg_thumb.png Binary files differnew file mode 100644 index 0000000..cbc8e0f --- /dev/null +++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_optim_OTreg_thumb.png diff --git a/docs/source/auto_examples/index.rst b/docs/source/auto_examples/index.rst new file mode 100644 index 0000000..6fac8c0 --- /dev/null +++ b/docs/source/auto_examples/index.rst @@ -0,0 +1,204 @@ +POT Examples +============ + +.. raw:: html + + <div class="sphx-glr-thumbcontainer" tooltip="@author: rflamary "> + +.. only:: html + + .. figure:: /auto_examples/images/thumb/sphx_glr_plot_OT_1D_thumb.png + + :ref:`sphx_glr_auto_examples_plot_OT_1D.py` + +.. raw:: html + + </div> + + +.. toctree:: + :hidden: + + /auto_examples/plot_OT_1D + +.. raw:: html + + <div class="sphx-glr-thumbcontainer" tooltip=" "> + +.. only:: html + + .. figure:: /auto_examples/images/thumb/sphx_glr_plot_optim_OTreg_thumb.png + + :ref:`sphx_glr_auto_examples_plot_optim_OTreg.py` + +.. raw:: html + + </div> + + +.. toctree:: + :hidden: + + /auto_examples/plot_optim_OTreg + +.. raw:: html + + <div class="sphx-glr-thumbcontainer" tooltip="@author: rflamary "> + +.. only:: html + + .. figure:: /auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png + + :ref:`sphx_glr_auto_examples_plot_OT_2D_samples.py` + +.. raw:: html + + </div> + + +.. toctree:: + :hidden: + + /auto_examples/plot_OT_2D_samples + +.. raw:: html + + <div class="sphx-glr-thumbcontainer" tooltip="[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized discrete optima..."> + +.. only:: html + + .. figure:: /auto_examples/images/thumb/sphx_glr_plot_OTDA_color_images_thumb.png + + :ref:`sphx_glr_auto_examples_plot_OTDA_color_images.py` + +.. raw:: html + + </div> + + +.. toctree:: + :hidden: + + /auto_examples/plot_OTDA_color_images + +.. raw:: html + + <div class="sphx-glr-thumbcontainer" tooltip=""> + +.. only:: html + + .. figure:: /auto_examples/images/thumb/sphx_glr_plot_OTDA_classes_thumb.png + + :ref:`sphx_glr_auto_examples_plot_OTDA_classes.py` + +.. raw:: html + + </div> + + +.. toctree:: + :hidden: + + /auto_examples/plot_OTDA_classes + +.. raw:: html + + <div class="sphx-glr-thumbcontainer" tooltip=""> + +.. only:: html + + .. figure:: /auto_examples/images/thumb/sphx_glr_plot_OTDA_2D_thumb.png + + :ref:`sphx_glr_auto_examples_plot_OTDA_2D.py` + +.. raw:: html + + </div> + + +.. toctree:: + :hidden: + + /auto_examples/plot_OTDA_2D + +.. raw:: html + + <div class="sphx-glr-thumbcontainer" tooltip=" @author: rflamary "> + +.. only:: html + + .. figure:: /auto_examples/images/thumb/sphx_glr_plot_barycenter_1D_thumb.png + + :ref:`sphx_glr_auto_examples_plot_barycenter_1D.py` + +.. raw:: html + + </div> + + +.. toctree:: + :hidden: + + /auto_examples/plot_barycenter_1D + +.. raw:: html + + <div class="sphx-glr-thumbcontainer" tooltip="[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized discrete op..."> + +.. only:: html + + .. figure:: /auto_examples/images/thumb/sphx_glr_plot_OTDA_mapping_color_images_thumb.png + + :ref:`sphx_glr_auto_examples_plot_OTDA_mapping_color_images.py` + +.. raw:: html + + </div> + + +.. toctree:: + :hidden: + + /auto_examples/plot_OTDA_mapping_color_images + +.. raw:: html + + <div class="sphx-glr-thumbcontainer" tooltip="[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for discrete optimal ..."> + +.. only:: html + + .. figure:: /auto_examples/images/thumb/sphx_glr_plot_OTDA_mapping_thumb.png + + :ref:`sphx_glr_auto_examples_plot_OTDA_mapping.py` + +.. raw:: html + + </div> + + +.. toctree:: + :hidden: + + /auto_examples/plot_OTDA_mapping +.. raw:: html + + <div style='clear:both'></div> + + + +.. container:: sphx-glr-footer + + + .. container:: sphx-glr-download + + :download:`Download all examples in Python source code: auto_examples_python.zip </auto_examples/auto_examples_python.zip>` + + + + .. container:: sphx-glr-download + + :download:`Download all examples in Jupyter notebooks: auto_examples_jupyter.zip </auto_examples/auto_examples_jupyter.zip>` + +.. rst-class:: sphx-glr-signature + + `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_ diff --git a/docs/source/auto_examples/plot_OTDA_2D.ipynb b/docs/source/auto_examples/plot_OTDA_2D.ipynb new file mode 100644 index 0000000..2ffb256 --- /dev/null +++ b/docs/source/auto_examples/plot_OTDA_2D.ipynb @@ -0,0 +1,54 @@ +{ + "nbformat_minor": 0, + "nbformat": 4, + "cells": [ + { + "execution_count": null, + "cell_type": "code", + "source": [ + "%matplotlib inline" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + }, + { + "source": [ + "\n# OT for empirical distributions\n\n\n\n" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "execution_count": null, + "cell_type": "code", + "source": [ + "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\n\n\n\n#%% parameters\n\nn=150 # nb bins\n\nxs,ys=ot.datasets.get_data_classif('3gauss',n)\nxt,yt=ot.datasets.get_data_classif('3gauss2',n)\n\na,b = ot.unif(n),ot.unif(n)\n# loss matrix\nM=ot.dist(xs,xt)\n#M/=M.max()\n\n#%% plot samples\n\npl.figure(1)\n\npl.subplot(2,2,1)\npl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')\npl.legend(loc=0)\npl.title('Source distributions')\n\npl.subplot(2,2,2)\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')\npl.legend(loc=0)\npl.title('target distributions')\n\npl.figure(2)\npl.imshow(M,interpolation='nearest')\npl.title('Cost matrix M')\n\n\n#%% OT estimation\n\n# EMD\nG0=ot.emd(a,b,M)\n\n# sinkhorn\nlambd=1e-1\nGs=ot.sinkhorn(a,b,M,lambd)\n\n\n# Group lasso regularization\nreg=1e-1\neta=1e0\nGg=ot.da.sinkhorn_lpl1_mm(a,ys.astype(np.int),b,M,reg,eta)\n\n\n#%% visu matrices\n\npl.figure(3)\n\npl.subplot(2,3,1)\npl.imshow(G0,interpolation='nearest')\npl.title('OT matrix ')\n\npl.subplot(2,3,2)\npl.imshow(Gs,interpolation='nearest')\npl.title('OT matrix Sinkhorn')\n\npl.subplot(2,3,3)\npl.imshow(Gg,interpolation='nearest')\npl.title('OT matrix Group lasso')\n\npl.subplot(2,3,4)\not.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])\npl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')\n\n\npl.subplot(2,3,5)\not.plot.plot2D_samples_mat(xs,xt,Gs,c=[.5,.5,1])\npl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')\n\npl.subplot(2,3,6)\not.plot.plot2D_samples_mat(xs,xt,Gg,c=[.5,.5,1])\npl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')\n\n#%% sample interpolation\n\nxst0=n*G0.dot(xt)\nxsts=n*Gs.dot(xt)\nxstg=n*Gg.dot(xt)\n\npl.figure(4)\npl.subplot(2,3,1)\n\n\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5)\npl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30)\npl.title('Interp samples')\npl.legend(loc=0)\n\npl.subplot(2,3,2)\n\n\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5)\npl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30)\npl.title('Interp samples Sinkhorn')\n\npl.subplot(2,3,3)\n\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5)\npl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30)\npl.title('Interp samples Grouplasso')" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "name": "python2", + "language": "python" + }, + "language_info": { + "mimetype": "text/x-python", + "nbconvert_exporter": "python", + "name": "python", + "file_extension": ".py", + "version": "2.7.12", + "pygments_lexer": "ipython2", + "codemirror_mode": { + "version": 2, + "name": "ipython" + } + } + } +}
\ No newline at end of file diff --git a/docs/source/auto_examples/plot_OTDA_2D.py b/docs/source/auto_examples/plot_OTDA_2D.py new file mode 100644 index 0000000..a1fb804 --- /dev/null +++ b/docs/source/auto_examples/plot_OTDA_2D.py @@ -0,0 +1,120 @@ +# -*- coding: utf-8 -*- +""" +============================== +OT for empirical distributions +============================== + +""" + +import numpy as np +import matplotlib.pylab as pl +import ot + + + +#%% parameters + +n=150 # nb bins + +xs,ys=ot.datasets.get_data_classif('3gauss',n) +xt,yt=ot.datasets.get_data_classif('3gauss2',n) + +a,b = ot.unif(n),ot.unif(n) +# loss matrix +M=ot.dist(xs,xt) +#M/=M.max() + +#%% plot samples + +pl.figure(1) + +pl.subplot(2,2,1) +pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') +pl.legend(loc=0) +pl.title('Source distributions') + +pl.subplot(2,2,2) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') +pl.legend(loc=0) +pl.title('target distributions') + +pl.figure(2) +pl.imshow(M,interpolation='nearest') +pl.title('Cost matrix M') + + +#%% OT estimation + +# EMD +G0=ot.emd(a,b,M) + +# sinkhorn +lambd=1e-1 +Gs=ot.sinkhorn(a,b,M,lambd) + + +# Group lasso regularization +reg=1e-1 +eta=1e0 +Gg=ot.da.sinkhorn_lpl1_mm(a,ys.astype(np.int),b,M,reg,eta) + + +#%% visu matrices + +pl.figure(3) + +pl.subplot(2,3,1) +pl.imshow(G0,interpolation='nearest') +pl.title('OT matrix ') + +pl.subplot(2,3,2) +pl.imshow(Gs,interpolation='nearest') +pl.title('OT matrix Sinkhorn') + +pl.subplot(2,3,3) +pl.imshow(Gg,interpolation='nearest') +pl.title('OT matrix Group lasso') + +pl.subplot(2,3,4) +ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1]) +pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') + + +pl.subplot(2,3,5) +ot.plot.plot2D_samples_mat(xs,xt,Gs,c=[.5,.5,1]) +pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') + +pl.subplot(2,3,6) +ot.plot.plot2D_samples_mat(xs,xt,Gg,c=[.5,.5,1]) +pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') + +#%% sample interpolation + +xst0=n*G0.dot(xt) +xsts=n*Gs.dot(xt) +xstg=n*Gg.dot(xt) + +pl.figure(4) +pl.subplot(2,3,1) + + +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5) +pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30) +pl.title('Interp samples') +pl.legend(loc=0) + +pl.subplot(2,3,2) + + +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5) +pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30) +pl.title('Interp samples Sinkhorn') + +pl.subplot(2,3,3) + +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5) +pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30) +pl.title('Interp samples Grouplasso')
\ No newline at end of file diff --git a/docs/source/auto_examples/plot_OTDA_2D.rst b/docs/source/auto_examples/plot_OTDA_2D.rst new file mode 100644 index 0000000..b535bb0 --- /dev/null +++ b/docs/source/auto_examples/plot_OTDA_2D.rst @@ -0,0 +1,175 @@ + + +.. _sphx_glr_auto_examples_plot_OTDA_2D.py: + + +============================== +OT for empirical distributions +============================== + + + + + +.. rst-class:: sphx-glr-horizontal + + + * + + .. image:: /auto_examples/images/sphx_glr_plot_OTDA_2D_001.png + :scale: 47 + + * + + .. image:: /auto_examples/images/sphx_glr_plot_OTDA_2D_002.png + :scale: 47 + + * + + .. image:: /auto_examples/images/sphx_glr_plot_OTDA_2D_003.png + :scale: 47 + + * + + .. image:: /auto_examples/images/sphx_glr_plot_OTDA_2D_004.png + :scale: 47 + + + + + +.. code-block:: python + + + import numpy as np + import matplotlib.pylab as pl + import ot + + + + #%% parameters + + n=150 # nb bins + + xs,ys=ot.datasets.get_data_classif('3gauss',n) + xt,yt=ot.datasets.get_data_classif('3gauss2',n) + + a,b = ot.unif(n),ot.unif(n) + # loss matrix + M=ot.dist(xs,xt) + #M/=M.max() + + #%% plot samples + + pl.figure(1) + + pl.subplot(2,2,1) + pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') + pl.legend(loc=0) + pl.title('Source distributions') + + pl.subplot(2,2,2) + pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') + pl.legend(loc=0) + pl.title('target distributions') + + pl.figure(2) + pl.imshow(M,interpolation='nearest') + pl.title('Cost matrix M') + + + #%% OT estimation + + # EMD + G0=ot.emd(a,b,M) + + # sinkhorn + lambd=1e-1 + Gs=ot.sinkhorn(a,b,M,lambd) + + + # Group lasso regularization + reg=1e-1 + eta=1e0 + Gg=ot.da.sinkhorn_lpl1_mm(a,ys.astype(np.int),b,M,reg,eta) + + + #%% visu matrices + + pl.figure(3) + + pl.subplot(2,3,1) + pl.imshow(G0,interpolation='nearest') + pl.title('OT matrix ') + + pl.subplot(2,3,2) + pl.imshow(Gs,interpolation='nearest') + pl.title('OT matrix Sinkhorn') + + pl.subplot(2,3,3) + pl.imshow(Gg,interpolation='nearest') + pl.title('OT matrix Group lasso') + + pl.subplot(2,3,4) + ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1]) + pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') + pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') + + + pl.subplot(2,3,5) + ot.plot.plot2D_samples_mat(xs,xt,Gs,c=[.5,.5,1]) + pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') + pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') + + pl.subplot(2,3,6) + ot.plot.plot2D_samples_mat(xs,xt,Gg,c=[.5,.5,1]) + pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') + pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') + + #%% sample interpolation + + xst0=n*G0.dot(xt) + xsts=n*Gs.dot(xt) + xstg=n*Gg.dot(xt) + + pl.figure(4) + pl.subplot(2,3,1) + + + pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5) + pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30) + pl.title('Interp samples') + pl.legend(loc=0) + + pl.subplot(2,3,2) + + + pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5) + pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30) + pl.title('Interp samples Sinkhorn') + + pl.subplot(2,3,3) + + pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5) + pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30) + pl.title('Interp samples Grouplasso') +**Total running time of the script:** ( 0 minutes 17.372 seconds) + + + +.. container:: sphx-glr-footer + + + .. container:: sphx-glr-download + + :download:`Download Python source code: plot_OTDA_2D.py <plot_OTDA_2D.py>` + + + + .. container:: sphx-glr-download + + :download:`Download Jupyter notebook: plot_OTDA_2D.ipynb <plot_OTDA_2D.ipynb>` + +.. rst-class:: sphx-glr-signature + + `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_ diff --git a/docs/source/auto_examples/plot_OTDA_classes.ipynb b/docs/source/auto_examples/plot_OTDA_classes.ipynb new file mode 100644 index 0000000..d9fcb87 --- /dev/null +++ b/docs/source/auto_examples/plot_OTDA_classes.ipynb @@ -0,0 +1,54 @@ +{ + "nbformat_minor": 0, + "nbformat": 4, + "cells": [ + { + "execution_count": null, + "cell_type": "code", + "source": [ + "%matplotlib inline" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + }, + { + "source": [ + "\n# OT for domain adaptation\n\n\n\n" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "execution_count": null, + "cell_type": "code", + "source": [ + "import matplotlib.pylab as pl\nimport ot\n\n\n\n\n#%% parameters\n\nn=150 # nb samples in source and target datasets\n\nxs,ys=ot.datasets.get_data_classif('3gauss',n)\nxt,yt=ot.datasets.get_data_classif('3gauss2',n)\n\n\n\n\n#%% plot samples\n\npl.figure(1)\n\npl.subplot(2,2,1)\npl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')\npl.legend(loc=0)\npl.title('Source distributions')\n\npl.subplot(2,2,2)\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')\npl.legend(loc=0)\npl.title('target distributions')\n\n\n#%% OT estimation\n\n# LP problem\nda_emd=ot.da.OTDA() # init class\nda_emd.fit(xs,xt) # fit distributions\nxst0=da_emd.interp() # interpolation of source samples\n\n\n# sinkhorn regularization\nlambd=1e-1\nda_entrop=ot.da.OTDA_sinkhorn()\nda_entrop.fit(xs,xt,reg=lambd)\nxsts=da_entrop.interp()\n\n# non-convex Group lasso regularization\nreg=1e-1\neta=1e0\nda_lpl1=ot.da.OTDA_lpl1()\nda_lpl1.fit(xs,ys,xt,reg=reg,eta=eta)\nxstg=da_lpl1.interp()\n\n\n# True Group lasso regularization\nreg=1e-1\neta=2e0\nda_l1l2=ot.da.OTDA_l1l2()\nda_l1l2.fit(xs,ys,xt,reg=reg,eta=eta,numItermax=20,verbose=True)\nxstgl=da_l1l2.interp()\n\n\n#%% plot interpolated source samples\npl.figure(4,(15,8))\n\nparam_img={'interpolation':'nearest','cmap':'jet'}\n\npl.subplot(2,4,1)\npl.imshow(da_emd.G,**param_img)\npl.title('OT matrix')\n\n\npl.subplot(2,4,2)\npl.imshow(da_entrop.G,**param_img)\npl.title('OT matrix sinkhorn')\n\npl.subplot(2,4,3)\npl.imshow(da_lpl1.G,**param_img)\npl.title('OT matrix non-convex Group Lasso')\n\npl.subplot(2,4,4)\npl.imshow(da_l1l2.G,**param_img)\npl.title('OT matrix Group Lasso')\n\n\npl.subplot(2,4,5)\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)\npl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30)\npl.title('Interp samples')\npl.legend(loc=0)\n\npl.subplot(2,4,6)\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)\npl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30)\npl.title('Interp samples Sinkhorn')\n\npl.subplot(2,4,7)\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)\npl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30)\npl.title('Interp samples non-convex Group Lasso')\n\npl.subplot(2,4,8)\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)\npl.scatter(xstgl[:,0],xstgl[:,1],c=ys,marker='+',label='Transp samples',s=30)\npl.title('Interp samples Group Lasso')" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "name": "python2", + "language": "python" + }, + "language_info": { + "mimetype": "text/x-python", + "nbconvert_exporter": "python", + "name": "python", + "file_extension": ".py", + "version": "2.7.12", + "pygments_lexer": "ipython2", + "codemirror_mode": { + "version": 2, + "name": "ipython" + } + } + } +}
\ No newline at end of file diff --git a/docs/source/auto_examples/plot_OTDA_classes.py b/docs/source/auto_examples/plot_OTDA_classes.py new file mode 100644 index 0000000..089b45b --- /dev/null +++ b/docs/source/auto_examples/plot_OTDA_classes.py @@ -0,0 +1,112 @@ +# -*- coding: utf-8 -*- +""" +======================== +OT for domain adaptation +======================== + +""" + +import matplotlib.pylab as pl +import ot + + + + +#%% parameters + +n=150 # nb samples in source and target datasets + +xs,ys=ot.datasets.get_data_classif('3gauss',n) +xt,yt=ot.datasets.get_data_classif('3gauss2',n) + + + + +#%% plot samples + +pl.figure(1) + +pl.subplot(2,2,1) +pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') +pl.legend(loc=0) +pl.title('Source distributions') + +pl.subplot(2,2,2) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') +pl.legend(loc=0) +pl.title('target distributions') + + +#%% OT estimation + +# LP problem +da_emd=ot.da.OTDA() # init class +da_emd.fit(xs,xt) # fit distributions +xst0=da_emd.interp() # interpolation of source samples + + +# sinkhorn regularization +lambd=1e-1 +da_entrop=ot.da.OTDA_sinkhorn() +da_entrop.fit(xs,xt,reg=lambd) +xsts=da_entrop.interp() + +# non-convex Group lasso regularization +reg=1e-1 +eta=1e0 +da_lpl1=ot.da.OTDA_lpl1() +da_lpl1.fit(xs,ys,xt,reg=reg,eta=eta) +xstg=da_lpl1.interp() + + +# True Group lasso regularization +reg=1e-1 +eta=2e0 +da_l1l2=ot.da.OTDA_l1l2() +da_l1l2.fit(xs,ys,xt,reg=reg,eta=eta,numItermax=20,verbose=True) +xstgl=da_l1l2.interp() + + +#%% plot interpolated source samples +pl.figure(4,(15,8)) + +param_img={'interpolation':'nearest','cmap':'jet'} + +pl.subplot(2,4,1) +pl.imshow(da_emd.G,**param_img) +pl.title('OT matrix') + + +pl.subplot(2,4,2) +pl.imshow(da_entrop.G,**param_img) +pl.title('OT matrix sinkhorn') + +pl.subplot(2,4,3) +pl.imshow(da_lpl1.G,**param_img) +pl.title('OT matrix non-convex Group Lasso') + +pl.subplot(2,4,4) +pl.imshow(da_l1l2.G,**param_img) +pl.title('OT matrix Group Lasso') + + +pl.subplot(2,4,5) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) +pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30) +pl.title('Interp samples') +pl.legend(loc=0) + +pl.subplot(2,4,6) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) +pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30) +pl.title('Interp samples Sinkhorn') + +pl.subplot(2,4,7) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) +pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30) +pl.title('Interp samples non-convex Group Lasso') + +pl.subplot(2,4,8) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) +pl.scatter(xstgl[:,0],xstgl[:,1],c=ys,marker='+',label='Transp samples',s=30) +pl.title('Interp samples Group Lasso')
\ No newline at end of file diff --git a/docs/source/auto_examples/plot_OTDA_classes.rst b/docs/source/auto_examples/plot_OTDA_classes.rst new file mode 100644 index 0000000..097e9fc --- /dev/null +++ b/docs/source/auto_examples/plot_OTDA_classes.rst @@ -0,0 +1,190 @@ + + +.. _sphx_glr_auto_examples_plot_OTDA_classes.py: + + +======================== +OT for domain adaptation +======================== + + + + + +.. rst-class:: sphx-glr-horizontal + + + * + + .. image:: /auto_examples/images/sphx_glr_plot_OTDA_classes_001.png + :scale: 47 + + * + + .. image:: /auto_examples/images/sphx_glr_plot_OTDA_classes_004.png + :scale: 47 + + +.. rst-class:: sphx-glr-script-out + + Out:: + + It. |Loss |Delta loss + -------------------------------- + 0|9.171271e+00|0.000000e+00 + 1|2.133783e+00|-3.298127e+00 + 2|1.895941e+00|-1.254484e-01 + 3|1.844628e+00|-2.781709e-02 + 4|1.824983e+00|-1.076467e-02 + 5|1.815453e+00|-5.249337e-03 + 6|1.808104e+00|-4.064733e-03 + 7|1.803558e+00|-2.520475e-03 + 8|1.801061e+00|-1.386155e-03 + 9|1.799391e+00|-9.279565e-04 + 10|1.797176e+00|-1.232778e-03 + 11|1.795465e+00|-9.529479e-04 + 12|1.795316e+00|-8.322362e-05 + 13|1.794523e+00|-4.418932e-04 + 14|1.794444e+00|-4.390599e-05 + 15|1.794395e+00|-2.710318e-05 + 16|1.793713e+00|-3.804028e-04 + 17|1.793110e+00|-3.359479e-04 + 18|1.792829e+00|-1.569563e-04 + 19|1.792621e+00|-1.159469e-04 + It. |Loss |Delta loss + -------------------------------- + 20|1.791334e+00|-7.187689e-04 + + + + +| + + +.. code-block:: python + + + import matplotlib.pylab as pl + import ot + + + + + #%% parameters + + n=150 # nb samples in source and target datasets + + xs,ys=ot.datasets.get_data_classif('3gauss',n) + xt,yt=ot.datasets.get_data_classif('3gauss2',n) + + + + + #%% plot samples + + pl.figure(1) + + pl.subplot(2,2,1) + pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') + pl.legend(loc=0) + pl.title('Source distributions') + + pl.subplot(2,2,2) + pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') + pl.legend(loc=0) + pl.title('target distributions') + + + #%% OT estimation + + # LP problem + da_emd=ot.da.OTDA() # init class + da_emd.fit(xs,xt) # fit distributions + xst0=da_emd.interp() # interpolation of source samples + + + # sinkhorn regularization + lambd=1e-1 + da_entrop=ot.da.OTDA_sinkhorn() + da_entrop.fit(xs,xt,reg=lambd) + xsts=da_entrop.interp() + + # non-convex Group lasso regularization + reg=1e-1 + eta=1e0 + da_lpl1=ot.da.OTDA_lpl1() + da_lpl1.fit(xs,ys,xt,reg=reg,eta=eta) + xstg=da_lpl1.interp() + + + # True Group lasso regularization + reg=1e-1 + eta=2e0 + da_l1l2=ot.da.OTDA_l1l2() + da_l1l2.fit(xs,ys,xt,reg=reg,eta=eta,numItermax=20,verbose=True) + xstgl=da_l1l2.interp() + + + #%% plot interpolated source samples + pl.figure(4,(15,8)) + + param_img={'interpolation':'nearest','cmap':'jet'} + + pl.subplot(2,4,1) + pl.imshow(da_emd.G,**param_img) + pl.title('OT matrix') + + + pl.subplot(2,4,2) + pl.imshow(da_entrop.G,**param_img) + pl.title('OT matrix sinkhorn') + + pl.subplot(2,4,3) + pl.imshow(da_lpl1.G,**param_img) + pl.title('OT matrix non-convex Group Lasso') + + pl.subplot(2,4,4) + pl.imshow(da_l1l2.G,**param_img) + pl.title('OT matrix Group Lasso') + + + pl.subplot(2,4,5) + pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) + pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30) + pl.title('Interp samples') + pl.legend(loc=0) + + pl.subplot(2,4,6) + pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) + pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30) + pl.title('Interp samples Sinkhorn') + + pl.subplot(2,4,7) + pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) + pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30) + pl.title('Interp samples non-convex Group Lasso') + + pl.subplot(2,4,8) + pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) + pl.scatter(xstgl[:,0],xstgl[:,1],c=ys,marker='+',label='Transp samples',s=30) + pl.title('Interp samples Group Lasso') +**Total running time of the script:** ( 0 minutes 2.225 seconds) + + + +.. container:: sphx-glr-footer + + + .. container:: sphx-glr-download + + :download:`Download Python source code: plot_OTDA_classes.py <plot_OTDA_classes.py>` + + + + .. container:: sphx-glr-download + + :download:`Download Jupyter notebook: plot_OTDA_classes.ipynb <plot_OTDA_classes.ipynb>` + +.. rst-class:: sphx-glr-signature + + `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_ diff --git a/docs/source/auto_examples/plot_OTDA_color_images.ipynb b/docs/source/auto_examples/plot_OTDA_color_images.ipynb new file mode 100644 index 0000000..d174828 --- /dev/null +++ b/docs/source/auto_examples/plot_OTDA_color_images.ipynb @@ -0,0 +1,54 @@ +{ + "nbformat_minor": 0, + "nbformat": 4, + "cells": [ + { + "execution_count": null, + "cell_type": "code", + "source": [ + "%matplotlib inline" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + }, + { + "source": [ + "\n========================================================\nOT for domain adaptation with image color adaptation [6]\n========================================================\n\n[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.\n\n" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "execution_count": null, + "cell_type": "code", + "source": [ + "import numpy as np\nimport scipy.ndimage as spi\nimport matplotlib.pylab as pl\nimport ot\n\n\n#%% Loading images\n\nI1=spi.imread('../data/ocean_day.jpg').astype(np.float64)/256\nI2=spi.imread('../data/ocean_sunset.jpg').astype(np.float64)/256\n\n#%% Plot images\n\npl.figure(1)\n\npl.subplot(1,2,1)\npl.imshow(I1)\npl.title('Image 1')\n\npl.subplot(1,2,2)\npl.imshow(I2)\npl.title('Image 2')\n\npl.show()\n\n#%% Image conversion and dataset generation\n\ndef im2mat(I):\n \"\"\"Converts and image to matrix (one pixel per line)\"\"\"\n return I.reshape((I.shape[0]*I.shape[1],I.shape[2]))\n\ndef mat2im(X,shape):\n \"\"\"Converts back a matrix to an image\"\"\"\n return X.reshape(shape)\n\nX1=im2mat(I1)\nX2=im2mat(I2)\n\n# training samples\nnb=1000\nidx1=np.random.randint(X1.shape[0],size=(nb,))\nidx2=np.random.randint(X2.shape[0],size=(nb,))\n\nxs=X1[idx1,:]\nxt=X2[idx2,:]\n\n#%% Plot image distributions\n\n\npl.figure(2,(10,5))\n\npl.subplot(1,2,1)\npl.scatter(xs[:,0],xs[:,2],c=xs)\npl.axis([0,1,0,1])\npl.xlabel('Red')\npl.ylabel('Blue')\npl.title('Image 1')\n\npl.subplot(1,2,2)\n#pl.imshow(I2)\npl.scatter(xt[:,0],xt[:,2],c=xt)\npl.axis([0,1,0,1])\npl.xlabel('Red')\npl.ylabel('Blue')\npl.title('Image 2')\n\npl.show()\n\n\n\n#%% domain adaptation between images\n\n# LP problem\nda_emd=ot.da.OTDA() # init class\nda_emd.fit(xs,xt) # fit distributions\n\n\n# sinkhorn regularization\nlambd=1e-1\nda_entrop=ot.da.OTDA_sinkhorn()\nda_entrop.fit(xs,xt,reg=lambd)\n\n\n\n#%% prediction between images (using out of sample prediction as in [6])\n\nX1t=da_emd.predict(X1)\nX2t=da_emd.predict(X2,-1)\n\n\nX1te=da_entrop.predict(X1)\nX2te=da_entrop.predict(X2,-1)\n\n\ndef minmax(I):\n return np.minimum(np.maximum(I,0),1)\n\nI1t=minmax(mat2im(X1t,I1.shape))\nI2t=minmax(mat2im(X2t,I2.shape))\n\nI1te=minmax(mat2im(X1te,I1.shape))\nI2te=minmax(mat2im(X2te,I2.shape))\n\n#%% plot all images\n\npl.figure(2,(10,8))\n\npl.subplot(2,3,1)\n\npl.imshow(I1)\npl.title('Image 1')\n\npl.subplot(2,3,2)\npl.imshow(I1t)\npl.title('Image 1 Adapt')\n\n\npl.subplot(2,3,3)\npl.imshow(I1te)\npl.title('Image 1 Adapt (reg)')\n\npl.subplot(2,3,4)\n\npl.imshow(I2)\npl.title('Image 2')\n\npl.subplot(2,3,5)\npl.imshow(I2t)\npl.title('Image 2 Adapt')\n\n\npl.subplot(2,3,6)\npl.imshow(I2te)\npl.title('Image 2 Adapt (reg)')\n\npl.show()" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "name": "python2", + "language": "python" + }, + "language_info": { + "mimetype": "text/x-python", + "nbconvert_exporter": "python", + "name": "python", + "file_extension": ".py", + "version": "2.7.12", + "pygments_lexer": "ipython2", + "codemirror_mode": { + "version": 2, + "name": "ipython" + } + } + } +}
\ No newline at end of file diff --git a/docs/source/auto_examples/plot_OTDA_color_images.py b/docs/source/auto_examples/plot_OTDA_color_images.py new file mode 100644 index 0000000..68eee44 --- /dev/null +++ b/docs/source/auto_examples/plot_OTDA_color_images.py @@ -0,0 +1,145 @@ +# -*- coding: utf-8 -*- +""" +======================================================== +OT for domain adaptation with image color adaptation [6] +======================================================== + +[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. +""" + +import numpy as np +import scipy.ndimage as spi +import matplotlib.pylab as pl +import ot + + +#%% Loading images + +I1=spi.imread('../data/ocean_day.jpg').astype(np.float64)/256 +I2=spi.imread('../data/ocean_sunset.jpg').astype(np.float64)/256 + +#%% Plot images + +pl.figure(1) + +pl.subplot(1,2,1) +pl.imshow(I1) +pl.title('Image 1') + +pl.subplot(1,2,2) +pl.imshow(I2) +pl.title('Image 2') + +pl.show() + +#%% Image conversion and dataset generation + +def im2mat(I): + """Converts and image to matrix (one pixel per line)""" + return I.reshape((I.shape[0]*I.shape[1],I.shape[2])) + +def mat2im(X,shape): + """Converts back a matrix to an image""" + return X.reshape(shape) + +X1=im2mat(I1) +X2=im2mat(I2) + +# training samples +nb=1000 +idx1=np.random.randint(X1.shape[0],size=(nb,)) +idx2=np.random.randint(X2.shape[0],size=(nb,)) + +xs=X1[idx1,:] +xt=X2[idx2,:] + +#%% Plot image distributions + + +pl.figure(2,(10,5)) + +pl.subplot(1,2,1) +pl.scatter(xs[:,0],xs[:,2],c=xs) +pl.axis([0,1,0,1]) +pl.xlabel('Red') +pl.ylabel('Blue') +pl.title('Image 1') + +pl.subplot(1,2,2) +#pl.imshow(I2) +pl.scatter(xt[:,0],xt[:,2],c=xt) +pl.axis([0,1,0,1]) +pl.xlabel('Red') +pl.ylabel('Blue') +pl.title('Image 2') + +pl.show() + + + +#%% domain adaptation between images + +# LP problem +da_emd=ot.da.OTDA() # init class +da_emd.fit(xs,xt) # fit distributions + + +# sinkhorn regularization +lambd=1e-1 +da_entrop=ot.da.OTDA_sinkhorn() +da_entrop.fit(xs,xt,reg=lambd) + + + +#%% prediction between images (using out of sample prediction as in [6]) + +X1t=da_emd.predict(X1) +X2t=da_emd.predict(X2,-1) + + +X1te=da_entrop.predict(X1) +X2te=da_entrop.predict(X2,-1) + + +def minmax(I): + return np.minimum(np.maximum(I,0),1) + +I1t=minmax(mat2im(X1t,I1.shape)) +I2t=minmax(mat2im(X2t,I2.shape)) + +I1te=minmax(mat2im(X1te,I1.shape)) +I2te=minmax(mat2im(X2te,I2.shape)) + +#%% plot all images + +pl.figure(2,(10,8)) + +pl.subplot(2,3,1) + +pl.imshow(I1) +pl.title('Image 1') + +pl.subplot(2,3,2) +pl.imshow(I1t) +pl.title('Image 1 Adapt') + + +pl.subplot(2,3,3) +pl.imshow(I1te) +pl.title('Image 1 Adapt (reg)') + +pl.subplot(2,3,4) + +pl.imshow(I2) +pl.title('Image 2') + +pl.subplot(2,3,5) +pl.imshow(I2t) +pl.title('Image 2 Adapt') + + +pl.subplot(2,3,6) +pl.imshow(I2te) +pl.title('Image 2 Adapt (reg)') + +pl.show() diff --git a/docs/source/auto_examples/plot_OTDA_color_images.rst b/docs/source/auto_examples/plot_OTDA_color_images.rst new file mode 100644 index 0000000..a982a90 --- /dev/null +++ b/docs/source/auto_examples/plot_OTDA_color_images.rst @@ -0,0 +1,191 @@ + + +.. _sphx_glr_auto_examples_plot_OTDA_color_images.py: + + +======================================================== +OT for domain adaptation with image color adaptation [6] +======================================================== + +[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. + + + + +.. rst-class:: sphx-glr-horizontal + + + * + + .. image:: /auto_examples/images/sphx_glr_plot_OTDA_color_images_001.png + :scale: 47 + + * + + .. image:: /auto_examples/images/sphx_glr_plot_OTDA_color_images_002.png + :scale: 47 + + + + + +.. code-block:: python + + + import numpy as np + import scipy.ndimage as spi + import matplotlib.pylab as pl + import ot + + + #%% Loading images + + I1=spi.imread('../data/ocean_day.jpg').astype(np.float64)/256 + I2=spi.imread('../data/ocean_sunset.jpg').astype(np.float64)/256 + + #%% Plot images + + pl.figure(1) + + pl.subplot(1,2,1) + pl.imshow(I1) + pl.title('Image 1') + + pl.subplot(1,2,2) + pl.imshow(I2) + pl.title('Image 2') + + pl.show() + + #%% Image conversion and dataset generation + + def im2mat(I): + """Converts and image to matrix (one pixel per line)""" + return I.reshape((I.shape[0]*I.shape[1],I.shape[2])) + + def mat2im(X,shape): + """Converts back a matrix to an image""" + return X.reshape(shape) + + X1=im2mat(I1) + X2=im2mat(I2) + + # training samples + nb=1000 + idx1=np.random.randint(X1.shape[0],size=(nb,)) + idx2=np.random.randint(X2.shape[0],size=(nb,)) + + xs=X1[idx1,:] + xt=X2[idx2,:] + + #%% Plot image distributions + + + pl.figure(2,(10,5)) + + pl.subplot(1,2,1) + pl.scatter(xs[:,0],xs[:,2],c=xs) + pl.axis([0,1,0,1]) + pl.xlabel('Red') + pl.ylabel('Blue') + pl.title('Image 1') + + pl.subplot(1,2,2) + #pl.imshow(I2) + pl.scatter(xt[:,0],xt[:,2],c=xt) + pl.axis([0,1,0,1]) + pl.xlabel('Red') + pl.ylabel('Blue') + pl.title('Image 2') + + pl.show() + + + + #%% domain adaptation between images + + # LP problem + da_emd=ot.da.OTDA() # init class + da_emd.fit(xs,xt) # fit distributions + + + # sinkhorn regularization + lambd=1e-1 + da_entrop=ot.da.OTDA_sinkhorn() + da_entrop.fit(xs,xt,reg=lambd) + + + + #%% prediction between images (using out of sample prediction as in [6]) + + X1t=da_emd.predict(X1) + X2t=da_emd.predict(X2,-1) + + + X1te=da_entrop.predict(X1) + X2te=da_entrop.predict(X2,-1) + + + def minmax(I): + return np.minimum(np.maximum(I,0),1) + + I1t=minmax(mat2im(X1t,I1.shape)) + I2t=minmax(mat2im(X2t,I2.shape)) + + I1te=minmax(mat2im(X1te,I1.shape)) + I2te=minmax(mat2im(X2te,I2.shape)) + + #%% plot all images + + pl.figure(2,(10,8)) + + pl.subplot(2,3,1) + + pl.imshow(I1) + pl.title('Image 1') + + pl.subplot(2,3,2) + pl.imshow(I1t) + pl.title('Image 1 Adapt') + + + pl.subplot(2,3,3) + pl.imshow(I1te) + pl.title('Image 1 Adapt (reg)') + + pl.subplot(2,3,4) + + pl.imshow(I2) + pl.title('Image 2') + + pl.subplot(2,3,5) + pl.imshow(I2t) + pl.title('Image 2 Adapt') + + + pl.subplot(2,3,6) + pl.imshow(I2te) + pl.title('Image 2 Adapt (reg)') + + pl.show() + +**Total running time of the script:** ( 0 minutes 24.815 seconds) + + + +.. container:: sphx-glr-footer + + + .. container:: sphx-glr-download + + :download:`Download Python source code: plot_OTDA_color_images.py <plot_OTDA_color_images.py>` + + + + .. container:: sphx-glr-download + + :download:`Download Jupyter notebook: plot_OTDA_color_images.ipynb <plot_OTDA_color_images.ipynb>` + +.. rst-class:: sphx-glr-signature + + `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_ diff --git a/docs/source/auto_examples/plot_OTDA_mapping.ipynb b/docs/source/auto_examples/plot_OTDA_mapping.ipynb new file mode 100644 index 0000000..ec405af --- /dev/null +++ b/docs/source/auto_examples/plot_OTDA_mapping.ipynb @@ -0,0 +1,54 @@ +{ + "nbformat_minor": 0, + "nbformat": 4, + "cells": [ + { + "execution_count": null, + "cell_type": "code", + "source": [ + "%matplotlib inline" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + }, + { + "source": [ + "\n===============================================\nOT mapping estimation for domain adaptation [8]\n===============================================\n\n[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, \"Mapping estimation for\n discrete optimal transport\", Neural Information Processing Systems (NIPS), 2016.\n\n" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "execution_count": null, + "cell_type": "code", + "source": [ + "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\n\n\n\n#%% dataset generation\n\nnp.random.seed(0) # makes example reproducible\n\nn=100 # nb samples in source and target datasets\ntheta=2*np.pi/20\nnz=0.1\nxs,ys=ot.datasets.get_data_classif('gaussrot',n,nz=nz)\nxt,yt=ot.datasets.get_data_classif('gaussrot',n,theta=theta,nz=nz)\n\n# one of the target mode changes its variance (no linear mapping)\nxt[yt==2]*=3\nxt=xt+4\n\n\n#%% plot samples\n\npl.figure(1,(8,5))\npl.clf()\n\npl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')\n\npl.legend(loc=0)\npl.title('Source and target distributions')\n\n\n\n#%% OT linear mapping estimation\n\neta=1e-8 # quadratic regularization for regression\nmu=1e0 # weight of the OT linear term\nbias=True # estimate a bias\n\not_mapping=ot.da.OTDA_mapping_linear()\not_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True)\n\nxst=ot_mapping.predict(xs) # use the estimated mapping\nxst0=ot_mapping.interp() # use barycentric mapping\n\n\npl.figure(2,(10,7))\npl.clf()\npl.subplot(2,2,1)\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3)\npl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='barycentric mapping')\npl.title(\"barycentric mapping\")\n\npl.subplot(2,2,2)\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3)\npl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping')\npl.title(\"Learned mapping\")\n\n\n\n#%% Kernel mapping estimation\n\neta=1e-5 # quadratic regularization for regression\nmu=1e-1 # weight of the OT linear term\nbias=True # estimate a bias\nsigma=1 # sigma bandwidth fot gaussian kernel\n\n\not_mapping_kernel=ot.da.OTDA_mapping_kernel()\not_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True)\n\nxst_kernel=ot_mapping_kernel.predict(xs) # use the estimated mapping\nxst0_kernel=ot_mapping_kernel.interp() # use barycentric mapping\n\n\n#%% Plotting the mapped samples\n\npl.figure(2,(10,7))\npl.clf()\npl.subplot(2,2,1)\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)\npl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Mapped source samples')\npl.title(\"Bary. mapping (linear)\")\npl.legend(loc=0)\n\npl.subplot(2,2,2)\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)\npl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping')\npl.title(\"Estim. mapping (linear)\")\n\npl.subplot(2,2,3)\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)\npl.scatter(xst0_kernel[:,0],xst0_kernel[:,1],c=ys,marker='+',label='barycentric mapping')\npl.title(\"Bary. mapping (kernel)\")\n\npl.subplot(2,2,4)\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)\npl.scatter(xst_kernel[:,0],xst_kernel[:,1],c=ys,marker='+',label='Learned mapping')\npl.title(\"Estim. mapping (kernel)\")" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "name": "python2", + "language": "python" + }, + "language_info": { + "mimetype": "text/x-python", + "nbconvert_exporter": "python", + "name": "python", + "file_extension": ".py", + "version": "2.7.12", + "pygments_lexer": "ipython2", + "codemirror_mode": { + "version": 2, + "name": "ipython" + } + } + } +}
\ No newline at end of file diff --git a/docs/source/auto_examples/plot_OTDA_mapping.py b/docs/source/auto_examples/plot_OTDA_mapping.py new file mode 100644 index 0000000..78b57e7 --- /dev/null +++ b/docs/source/auto_examples/plot_OTDA_mapping.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- +""" +=============================================== +OT mapping estimation for domain adaptation [8] +=============================================== + +[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for + discrete optimal transport", Neural Information Processing Systems (NIPS), 2016. +""" + +import numpy as np +import matplotlib.pylab as pl +import ot + + + +#%% dataset generation + +np.random.seed(0) # makes example reproducible + +n=100 # nb samples in source and target datasets +theta=2*np.pi/20 +nz=0.1 +xs,ys=ot.datasets.get_data_classif('gaussrot',n,nz=nz) +xt,yt=ot.datasets.get_data_classif('gaussrot',n,theta=theta,nz=nz) + +# one of the target mode changes its variance (no linear mapping) +xt[yt==2]*=3 +xt=xt+4 + + +#%% plot samples + +pl.figure(1,(8,5)) +pl.clf() + +pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') + +pl.legend(loc=0) +pl.title('Source and target distributions') + + + +#%% OT linear mapping estimation + +eta=1e-8 # quadratic regularization for regression +mu=1e0 # weight of the OT linear term +bias=True # estimate a bias + +ot_mapping=ot.da.OTDA_mapping_linear() +ot_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True) + +xst=ot_mapping.predict(xs) # use the estimated mapping +xst0=ot_mapping.interp() # use barycentric mapping + + +pl.figure(2,(10,7)) +pl.clf() +pl.subplot(2,2,1) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3) +pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='barycentric mapping') +pl.title("barycentric mapping") + +pl.subplot(2,2,2) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3) +pl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping') +pl.title("Learned mapping") + + + +#%% Kernel mapping estimation + +eta=1e-5 # quadratic regularization for regression +mu=1e-1 # weight of the OT linear term +bias=True # estimate a bias +sigma=1 # sigma bandwidth fot gaussian kernel + + +ot_mapping_kernel=ot.da.OTDA_mapping_kernel() +ot_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True) + +xst_kernel=ot_mapping_kernel.predict(xs) # use the estimated mapping +xst0_kernel=ot_mapping_kernel.interp() # use barycentric mapping + + +#%% Plotting the mapped samples + +pl.figure(2,(10,7)) +pl.clf() +pl.subplot(2,2,1) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) +pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Mapped source samples') +pl.title("Bary. mapping (linear)") +pl.legend(loc=0) + +pl.subplot(2,2,2) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) +pl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping') +pl.title("Estim. mapping (linear)") + +pl.subplot(2,2,3) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) +pl.scatter(xst0_kernel[:,0],xst0_kernel[:,1],c=ys,marker='+',label='barycentric mapping') +pl.title("Bary. mapping (kernel)") + +pl.subplot(2,2,4) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) +pl.scatter(xst_kernel[:,0],xst_kernel[:,1],c=ys,marker='+',label='Learned mapping') +pl.title("Estim. mapping (kernel)") diff --git a/docs/source/auto_examples/plot_OTDA_mapping.rst b/docs/source/auto_examples/plot_OTDA_mapping.rst new file mode 100644 index 0000000..18da90d --- /dev/null +++ b/docs/source/auto_examples/plot_OTDA_mapping.rst @@ -0,0 +1,186 @@ + + +.. _sphx_glr_auto_examples_plot_OTDA_mapping.py: + + +=============================================== +OT mapping estimation for domain adaptation [8] +=============================================== + +[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for + discrete optimal transport", Neural Information Processing Systems (NIPS), 2016. + + + + +.. rst-class:: sphx-glr-horizontal + + + * + + .. image:: /auto_examples/images/sphx_glr_plot_OTDA_mapping_001.png + :scale: 47 + + * + + .. image:: /auto_examples/images/sphx_glr_plot_OTDA_mapping_002.png + :scale: 47 + + +.. rst-class:: sphx-glr-script-out + + Out:: + + It. |Loss |Delta loss + -------------------------------- + 0|4.009366e+03|0.000000e+00 + 1|3.999933e+03|-2.352753e-03 + 2|3.999520e+03|-1.031984e-04 + 3|3.999362e+03|-3.936391e-05 + 4|3.999281e+03|-2.032868e-05 + 5|3.999238e+03|-1.083083e-05 + 6|3.999229e+03|-2.125291e-06 + It. |Loss |Delta loss + -------------------------------- + 0|4.026841e+02|0.000000e+00 + 1|3.990791e+02|-8.952439e-03 + 2|3.987954e+02|-7.107124e-04 + 3|3.986554e+02|-3.512453e-04 + 4|3.985721e+02|-2.087997e-04 + 5|3.985141e+02|-1.456184e-04 + 6|3.984729e+02|-1.034624e-04 + 7|3.984435e+02|-7.366943e-05 + 8|3.984199e+02|-5.922497e-05 + 9|3.984016e+02|-4.593063e-05 + 10|3.983867e+02|-3.733061e-05 + + + + +| + + +.. code-block:: python + + + import numpy as np + import matplotlib.pylab as pl + import ot + + + + #%% dataset generation + + np.random.seed(0) # makes example reproducible + + n=100 # nb samples in source and target datasets + theta=2*np.pi/20 + nz=0.1 + xs,ys=ot.datasets.get_data_classif('gaussrot',n,nz=nz) + xt,yt=ot.datasets.get_data_classif('gaussrot',n,theta=theta,nz=nz) + + # one of the target mode changes its variance (no linear mapping) + xt[yt==2]*=3 + xt=xt+4 + + + #%% plot samples + + pl.figure(1,(8,5)) + pl.clf() + + pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') + pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') + + pl.legend(loc=0) + pl.title('Source and target distributions') + + + + #%% OT linear mapping estimation + + eta=1e-8 # quadratic regularization for regression + mu=1e0 # weight of the OT linear term + bias=True # estimate a bias + + ot_mapping=ot.da.OTDA_mapping_linear() + ot_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True) + + xst=ot_mapping.predict(xs) # use the estimated mapping + xst0=ot_mapping.interp() # use barycentric mapping + + + pl.figure(2,(10,7)) + pl.clf() + pl.subplot(2,2,1) + pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3) + pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='barycentric mapping') + pl.title("barycentric mapping") + + pl.subplot(2,2,2) + pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3) + pl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping') + pl.title("Learned mapping") + + + + #%% Kernel mapping estimation + + eta=1e-5 # quadratic regularization for regression + mu=1e-1 # weight of the OT linear term + bias=True # estimate a bias + sigma=1 # sigma bandwidth fot gaussian kernel + + + ot_mapping_kernel=ot.da.OTDA_mapping_kernel() + ot_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True) + + xst_kernel=ot_mapping_kernel.predict(xs) # use the estimated mapping + xst0_kernel=ot_mapping_kernel.interp() # use barycentric mapping + + + #%% Plotting the mapped samples + + pl.figure(2,(10,7)) + pl.clf() + pl.subplot(2,2,1) + pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) + pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Mapped source samples') + pl.title("Bary. mapping (linear)") + pl.legend(loc=0) + + pl.subplot(2,2,2) + pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) + pl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping') + pl.title("Estim. mapping (linear)") + + pl.subplot(2,2,3) + pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) + pl.scatter(xst0_kernel[:,0],xst0_kernel[:,1],c=ys,marker='+',label='barycentric mapping') + pl.title("Bary. mapping (kernel)") + + pl.subplot(2,2,4) + pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) + pl.scatter(xst_kernel[:,0],xst_kernel[:,1],c=ys,marker='+',label='Learned mapping') + pl.title("Estim. mapping (kernel)") + +**Total running time of the script:** ( 0 minutes 0.882 seconds) + + + +.. container:: sphx-glr-footer + + + .. container:: sphx-glr-download + + :download:`Download Python source code: plot_OTDA_mapping.py <plot_OTDA_mapping.py>` + + + + .. container:: sphx-glr-download + + :download:`Download Jupyter notebook: plot_OTDA_mapping.ipynb <plot_OTDA_mapping.ipynb>` + +.. rst-class:: sphx-glr-signature + + `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_ diff --git a/docs/source/auto_examples/plot_OTDA_mapping_color_images.ipynb b/docs/source/auto_examples/plot_OTDA_mapping_color_images.ipynb new file mode 100644 index 0000000..1136cc3 --- /dev/null +++ b/docs/source/auto_examples/plot_OTDA_mapping_color_images.ipynb @@ -0,0 +1,54 @@ +{ + "nbformat_minor": 0, + "nbformat": 4, + "cells": [ + { + "execution_count": null, + "cell_type": "code", + "source": [ + "%matplotlib inline" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + }, + { + "source": [ + "\n====================================================================================\nOT for domain adaptation with image color adaptation [6] with mapping estimation [8]\n====================================================================================\n\n[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized\n discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.\n[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, \"Mapping estimation for\n discrete optimal transport\", Neural Information Processing Systems (NIPS), 2016.\n\n\n" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "execution_count": null, + "cell_type": "code", + "source": [ + "import numpy as np\nimport scipy.ndimage as spi\nimport matplotlib.pylab as pl\nimport ot\n\n\n#%% Loading images\n\nI1=spi.imread('../data/ocean_day.jpg').astype(np.float64)/256\nI2=spi.imread('../data/ocean_sunset.jpg').astype(np.float64)/256\n\n#%% Plot images\n\npl.figure(1)\n\npl.subplot(1,2,1)\npl.imshow(I1)\npl.title('Image 1')\n\npl.subplot(1,2,2)\npl.imshow(I2)\npl.title('Image 2')\n\npl.show()\n\n#%% Image conversion and dataset generation\n\ndef im2mat(I):\n \"\"\"Converts and image to matrix (one pixel per line)\"\"\"\n return I.reshape((I.shape[0]*I.shape[1],I.shape[2]))\n\ndef mat2im(X,shape):\n \"\"\"Converts back a matrix to an image\"\"\"\n return X.reshape(shape)\n\nX1=im2mat(I1)\nX2=im2mat(I2)\n\n# training samples\nnb=1000\nidx1=np.random.randint(X1.shape[0],size=(nb,))\nidx2=np.random.randint(X2.shape[0],size=(nb,))\n\nxs=X1[idx1,:]\nxt=X2[idx2,:]\n\n#%% Plot image distributions\n\n\npl.figure(2,(10,5))\n\npl.subplot(1,2,1)\npl.scatter(xs[:,0],xs[:,2],c=xs)\npl.axis([0,1,0,1])\npl.xlabel('Red')\npl.ylabel('Blue')\npl.title('Image 1')\n\npl.subplot(1,2,2)\n#pl.imshow(I2)\npl.scatter(xt[:,0],xt[:,2],c=xt)\npl.axis([0,1,0,1])\npl.xlabel('Red')\npl.ylabel('Blue')\npl.title('Image 2')\n\npl.show()\n\n\n\n#%% domain adaptation between images\ndef minmax(I):\n return np.minimum(np.maximum(I,0),1)\n# LP problem\nda_emd=ot.da.OTDA() # init class\nda_emd.fit(xs,xt) # fit distributions\n\nX1t=da_emd.predict(X1) # out of sample\nI1t=minmax(mat2im(X1t,I1.shape))\n\n# sinkhorn regularization\nlambd=1e-1\nda_entrop=ot.da.OTDA_sinkhorn()\nda_entrop.fit(xs,xt,reg=lambd)\n\nX1te=da_entrop.predict(X1)\nI1te=minmax(mat2im(X1te,I1.shape))\n\n# linear mapping estimation\neta=1e-8 # quadratic regularization for regression\nmu=1e0 # weight of the OT linear term\nbias=True # estimate a bias\n\not_mapping=ot.da.OTDA_mapping_linear()\not_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True)\n\nX1tl=ot_mapping.predict(X1) # use the estimated mapping\nI1tl=minmax(mat2im(X1tl,I1.shape))\n\n# nonlinear mapping estimation\neta=1e-2 # quadratic regularization for regression\nmu=1e0 # weight of the OT linear term\nbias=False # estimate a bias\nsigma=1 # sigma bandwidth fot gaussian kernel\n\n\not_mapping_kernel=ot.da.OTDA_mapping_kernel()\not_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True)\n\nX1tn=ot_mapping_kernel.predict(X1) # use the estimated mapping\nI1tn=minmax(mat2im(X1tn,I1.shape))\n#%% plot images\n\n\npl.figure(2,(10,8))\n\npl.subplot(2,3,1)\n\npl.imshow(I1)\npl.title('Im. 1')\n\npl.subplot(2,3,2)\n\npl.imshow(I2)\npl.title('Im. 2')\n\n\npl.subplot(2,3,3)\npl.imshow(I1t)\npl.title('Im. 1 Interp LP')\n\npl.subplot(2,3,4)\npl.imshow(I1te)\npl.title('Im. 1 Interp Entrop')\n\n\npl.subplot(2,3,5)\npl.imshow(I1tl)\npl.title('Im. 1 Linear mapping')\n\npl.subplot(2,3,6)\npl.imshow(I1tn)\npl.title('Im. 1 nonlinear mapping')\n\npl.show()" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "name": "python2", + "language": "python" + }, + "language_info": { + "mimetype": "text/x-python", + "nbconvert_exporter": "python", + "name": "python", + "file_extension": ".py", + "version": "2.7.12", + "pygments_lexer": "ipython2", + "codemirror_mode": { + "version": 2, + "name": "ipython" + } + } + } +}
\ No newline at end of file diff --git a/docs/source/auto_examples/plot_OTDA_mapping_color_images.py b/docs/source/auto_examples/plot_OTDA_mapping_color_images.py new file mode 100644 index 0000000..f07dc6c --- /dev/null +++ b/docs/source/auto_examples/plot_OTDA_mapping_color_images.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- +""" +==================================================================================== +OT for domain adaptation with image color adaptation [6] with mapping estimation [8] +==================================================================================== + +[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized + discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. +[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for + discrete optimal transport", Neural Information Processing Systems (NIPS), 2016. + +""" + +import numpy as np +import scipy.ndimage as spi +import matplotlib.pylab as pl +import ot + + +#%% Loading images + +I1=spi.imread('../data/ocean_day.jpg').astype(np.float64)/256 +I2=spi.imread('../data/ocean_sunset.jpg').astype(np.float64)/256 + +#%% Plot images + +pl.figure(1) + +pl.subplot(1,2,1) +pl.imshow(I1) +pl.title('Image 1') + +pl.subplot(1,2,2) +pl.imshow(I2) +pl.title('Image 2') + +pl.show() + +#%% Image conversion and dataset generation + +def im2mat(I): + """Converts and image to matrix (one pixel per line)""" + return I.reshape((I.shape[0]*I.shape[1],I.shape[2])) + +def mat2im(X,shape): + """Converts back a matrix to an image""" + return X.reshape(shape) + +X1=im2mat(I1) +X2=im2mat(I2) + +# training samples +nb=1000 +idx1=np.random.randint(X1.shape[0],size=(nb,)) +idx2=np.random.randint(X2.shape[0],size=(nb,)) + +xs=X1[idx1,:] +xt=X2[idx2,:] + +#%% Plot image distributions + + +pl.figure(2,(10,5)) + +pl.subplot(1,2,1) +pl.scatter(xs[:,0],xs[:,2],c=xs) +pl.axis([0,1,0,1]) +pl.xlabel('Red') +pl.ylabel('Blue') +pl.title('Image 1') + +pl.subplot(1,2,2) +#pl.imshow(I2) +pl.scatter(xt[:,0],xt[:,2],c=xt) +pl.axis([0,1,0,1]) +pl.xlabel('Red') +pl.ylabel('Blue') +pl.title('Image 2') + +pl.show() + + + +#%% domain adaptation between images +def minmax(I): + return np.minimum(np.maximum(I,0),1) +# LP problem +da_emd=ot.da.OTDA() # init class +da_emd.fit(xs,xt) # fit distributions + +X1t=da_emd.predict(X1) # out of sample +I1t=minmax(mat2im(X1t,I1.shape)) + +# sinkhorn regularization +lambd=1e-1 +da_entrop=ot.da.OTDA_sinkhorn() +da_entrop.fit(xs,xt,reg=lambd) + +X1te=da_entrop.predict(X1) +I1te=minmax(mat2im(X1te,I1.shape)) + +# linear mapping estimation +eta=1e-8 # quadratic regularization for regression +mu=1e0 # weight of the OT linear term +bias=True # estimate a bias + +ot_mapping=ot.da.OTDA_mapping_linear() +ot_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True) + +X1tl=ot_mapping.predict(X1) # use the estimated mapping +I1tl=minmax(mat2im(X1tl,I1.shape)) + +# nonlinear mapping estimation +eta=1e-2 # quadratic regularization for regression +mu=1e0 # weight of the OT linear term +bias=False # estimate a bias +sigma=1 # sigma bandwidth fot gaussian kernel + + +ot_mapping_kernel=ot.da.OTDA_mapping_kernel() +ot_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True) + +X1tn=ot_mapping_kernel.predict(X1) # use the estimated mapping +I1tn=minmax(mat2im(X1tn,I1.shape)) +#%% plot images + + +pl.figure(2,(10,8)) + +pl.subplot(2,3,1) + +pl.imshow(I1) +pl.title('Im. 1') + +pl.subplot(2,3,2) + +pl.imshow(I2) +pl.title('Im. 2') + + +pl.subplot(2,3,3) +pl.imshow(I1t) +pl.title('Im. 1 Interp LP') + +pl.subplot(2,3,4) +pl.imshow(I1te) +pl.title('Im. 1 Interp Entrop') + + +pl.subplot(2,3,5) +pl.imshow(I1tl) +pl.title('Im. 1 Linear mapping') + +pl.subplot(2,3,6) +pl.imshow(I1tn) +pl.title('Im. 1 nonlinear mapping') + +pl.show() diff --git a/docs/source/auto_examples/plot_OTDA_mapping_color_images.rst b/docs/source/auto_examples/plot_OTDA_mapping_color_images.rst new file mode 100644 index 0000000..60be3a4 --- /dev/null +++ b/docs/source/auto_examples/plot_OTDA_mapping_color_images.rst @@ -0,0 +1,246 @@ + + +.. _sphx_glr_auto_examples_plot_OTDA_mapping_color_images.py: + + +==================================================================================== +OT for domain adaptation with image color adaptation [6] with mapping estimation [8] +==================================================================================== + +[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized + discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. +[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for + discrete optimal transport", Neural Information Processing Systems (NIPS), 2016. + + + + + +.. rst-class:: sphx-glr-horizontal + + + * + + .. image:: /auto_examples/images/sphx_glr_plot_OTDA_mapping_color_images_001.png + :scale: 47 + + * + + .. image:: /auto_examples/images/sphx_glr_plot_OTDA_mapping_color_images_002.png + :scale: 47 + + +.. rst-class:: sphx-glr-script-out + + Out:: + + It. |Loss |Delta loss + -------------------------------- + 0|3.624802e+02|0.000000e+00 + 1|3.547180e+02|-2.141395e-02 + 2|3.545494e+02|-4.753955e-04 + 3|3.544646e+02|-2.391784e-04 + 4|3.544126e+02|-1.466280e-04 + 5|3.543775e+02|-9.921805e-05 + 6|3.543518e+02|-7.245828e-05 + 7|3.543323e+02|-5.491924e-05 + 8|3.543170e+02|-4.342401e-05 + 9|3.543046e+02|-3.472174e-05 + 10|3.542945e+02|-2.878681e-05 + 11|3.542859e+02|-2.417065e-05 + 12|3.542786e+02|-2.058131e-05 + 13|3.542723e+02|-1.768262e-05 + 14|3.542668e+02|-1.551616e-05 + 15|3.542620e+02|-1.371909e-05 + 16|3.542577e+02|-1.213326e-05 + 17|3.542538e+02|-1.085481e-05 + 18|3.542531e+02|-1.996006e-06 + It. |Loss |Delta loss + -------------------------------- + 0|3.555768e+02|0.000000e+00 + 1|3.510071e+02|-1.285164e-02 + 2|3.509110e+02|-2.736701e-04 + 3|3.508748e+02|-1.031476e-04 + 4|3.508506e+02|-6.910585e-05 + 5|3.508330e+02|-5.014608e-05 + 6|3.508195e+02|-3.839166e-05 + 7|3.508090e+02|-3.004218e-05 + 8|3.508005e+02|-2.417627e-05 + 9|3.507935e+02|-2.004621e-05 + 10|3.507876e+02|-1.681731e-05 + + + + +| + + +.. code-block:: python + + + import numpy as np + import scipy.ndimage as spi + import matplotlib.pylab as pl + import ot + + + #%% Loading images + + I1=spi.imread('../data/ocean_day.jpg').astype(np.float64)/256 + I2=spi.imread('../data/ocean_sunset.jpg').astype(np.float64)/256 + + #%% Plot images + + pl.figure(1) + + pl.subplot(1,2,1) + pl.imshow(I1) + pl.title('Image 1') + + pl.subplot(1,2,2) + pl.imshow(I2) + pl.title('Image 2') + + pl.show() + + #%% Image conversion and dataset generation + + def im2mat(I): + """Converts and image to matrix (one pixel per line)""" + return I.reshape((I.shape[0]*I.shape[1],I.shape[2])) + + def mat2im(X,shape): + """Converts back a matrix to an image""" + return X.reshape(shape) + + X1=im2mat(I1) + X2=im2mat(I2) + + # training samples + nb=1000 + idx1=np.random.randint(X1.shape[0],size=(nb,)) + idx2=np.random.randint(X2.shape[0],size=(nb,)) + + xs=X1[idx1,:] + xt=X2[idx2,:] + + #%% Plot image distributions + + + pl.figure(2,(10,5)) + + pl.subplot(1,2,1) + pl.scatter(xs[:,0],xs[:,2],c=xs) + pl.axis([0,1,0,1]) + pl.xlabel('Red') + pl.ylabel('Blue') + pl.title('Image 1') + + pl.subplot(1,2,2) + #pl.imshow(I2) + pl.scatter(xt[:,0],xt[:,2],c=xt) + pl.axis([0,1,0,1]) + pl.xlabel('Red') + pl.ylabel('Blue') + pl.title('Image 2') + + pl.show() + + + + #%% domain adaptation between images + def minmax(I): + return np.minimum(np.maximum(I,0),1) + # LP problem + da_emd=ot.da.OTDA() # init class + da_emd.fit(xs,xt) # fit distributions + + X1t=da_emd.predict(X1) # out of sample + I1t=minmax(mat2im(X1t,I1.shape)) + + # sinkhorn regularization + lambd=1e-1 + da_entrop=ot.da.OTDA_sinkhorn() + da_entrop.fit(xs,xt,reg=lambd) + + X1te=da_entrop.predict(X1) + I1te=minmax(mat2im(X1te,I1.shape)) + + # linear mapping estimation + eta=1e-8 # quadratic regularization for regression + mu=1e0 # weight of the OT linear term + bias=True # estimate a bias + + ot_mapping=ot.da.OTDA_mapping_linear() + ot_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True) + + X1tl=ot_mapping.predict(X1) # use the estimated mapping + I1tl=minmax(mat2im(X1tl,I1.shape)) + + # nonlinear mapping estimation + eta=1e-2 # quadratic regularization for regression + mu=1e0 # weight of the OT linear term + bias=False # estimate a bias + sigma=1 # sigma bandwidth fot gaussian kernel + + + ot_mapping_kernel=ot.da.OTDA_mapping_kernel() + ot_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True) + + X1tn=ot_mapping_kernel.predict(X1) # use the estimated mapping + I1tn=minmax(mat2im(X1tn,I1.shape)) + #%% plot images + + + pl.figure(2,(10,8)) + + pl.subplot(2,3,1) + + pl.imshow(I1) + pl.title('Im. 1') + + pl.subplot(2,3,2) + + pl.imshow(I2) + pl.title('Im. 2') + + + pl.subplot(2,3,3) + pl.imshow(I1t) + pl.title('Im. 1 Interp LP') + + pl.subplot(2,3,4) + pl.imshow(I1te) + pl.title('Im. 1 Interp Entrop') + + + pl.subplot(2,3,5) + pl.imshow(I1tl) + pl.title('Im. 1 Linear mapping') + + pl.subplot(2,3,6) + pl.imshow(I1tn) + pl.title('Im. 1 nonlinear mapping') + + pl.show() + +**Total running time of the script:** ( 1 minutes 59.537 seconds) + + + +.. container:: sphx-glr-footer + + + .. container:: sphx-glr-download + + :download:`Download Python source code: plot_OTDA_mapping_color_images.py <plot_OTDA_mapping_color_images.py>` + + + + .. container:: sphx-glr-download + + :download:`Download Jupyter notebook: plot_OTDA_mapping_color_images.ipynb <plot_OTDA_mapping_color_images.ipynb>` + +.. rst-class:: sphx-glr-signature + + `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_ diff --git a/docs/source/auto_examples/plot_OT_1D.ipynb b/docs/source/auto_examples/plot_OT_1D.ipynb new file mode 100644 index 0000000..17d0b21 --- /dev/null +++ b/docs/source/auto_examples/plot_OT_1D.ipynb @@ -0,0 +1,54 @@ +{ + "nbformat_minor": 0, + "nbformat": 4, + "cells": [ + { + "execution_count": null, + "cell_type": "code", + "source": [ + "%matplotlib inline" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + }, + { + "source": [ + "\n# 1D optimal transport\n\n\n@author: rflamary\n\n" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "execution_count": null, + "cell_type": "code", + "source": [ + "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\nfrom ot.datasets import get_1D_gauss as gauss\n\n\n#%% parameters\n\nn=100 # nb bins\n\n# bin positions\nx=np.arange(n,dtype=np.float64)\n\n# Gaussian distributions\na=gauss(n,m=20,s=5) # m= mean, s= std\nb=gauss(n,m=60,s=10)\n\n# loss matrix\nM=ot.dist(x.reshape((n,1)),x.reshape((n,1)))\nM/=M.max()\n\n#%% plot the distributions\n\npl.figure(1)\npl.plot(x,a,'b',label='Source distribution')\npl.plot(x,b,'r',label='Target distribution')\npl.legend()\n\n#%% plot distributions and loss matrix\n\npl.figure(2)\not.plot.plot1D_mat(a,b,M,'Cost matrix M')\n\n#%% EMD\n\nG0=ot.emd(a,b,M)\n\npl.figure(3)\not.plot.plot1D_mat(a,b,G0,'OT matrix G0')\n\n#%% Sinkhorn\n\nlambd=1e-3\nGs=ot.sinkhorn(a,b,M,lambd)\n\npl.figure(4)\not.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn')" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "name": "python2", + "language": "python" + }, + "language_info": { + "mimetype": "text/x-python", + "nbconvert_exporter": "python", + "name": "python", + "file_extension": ".py", + "version": "2.7.12", + "pygments_lexer": "ipython2", + "codemirror_mode": { + "version": 2, + "name": "ipython" + } + } + } +}
\ No newline at end of file diff --git a/docs/source/auto_examples/plot_OT_1D.py b/docs/source/auto_examples/plot_OT_1D.py new file mode 100644 index 0000000..e5719eb --- /dev/null +++ b/docs/source/auto_examples/plot_OT_1D.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +""" +==================== +1D optimal transport +==================== + +@author: rflamary +""" + +import numpy as np +import matplotlib.pylab as pl +import ot +from ot.datasets import get_1D_gauss as gauss + + +#%% parameters + +n=100 # nb bins + +# bin positions +x=np.arange(n,dtype=np.float64) + +# Gaussian distributions +a=gauss(n,m=20,s=5) # m= mean, s= std +b=gauss(n,m=60,s=10) + +# loss matrix +M=ot.dist(x.reshape((n,1)),x.reshape((n,1))) +M/=M.max() + +#%% plot the distributions + +pl.figure(1) +pl.plot(x,a,'b',label='Source distribution') +pl.plot(x,b,'r',label='Target distribution') +pl.legend() + +#%% plot distributions and loss matrix + +pl.figure(2) +ot.plot.plot1D_mat(a,b,M,'Cost matrix M') + +#%% EMD + +G0=ot.emd(a,b,M) + +pl.figure(3) +ot.plot.plot1D_mat(a,b,G0,'OT matrix G0') + +#%% Sinkhorn + +lambd=1e-3 +Gs=ot.sinkhorn(a,b,M,lambd) + +pl.figure(4) +ot.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn') diff --git a/docs/source/auto_examples/plot_OT_1D.rst b/docs/source/auto_examples/plot_OT_1D.rst new file mode 100644 index 0000000..941fd54 --- /dev/null +++ b/docs/source/auto_examples/plot_OT_1D.rst @@ -0,0 +1,112 @@ + + +.. _sphx_glr_auto_examples_plot_OT_1D.py: + + +==================== +1D optimal transport +==================== + +@author: rflamary + + + + +.. rst-class:: sphx-glr-horizontal + + + * + + .. image:: /auto_examples/images/sphx_glr_plot_OT_1D_001.png + :scale: 47 + + * + + .. image:: /auto_examples/images/sphx_glr_plot_OT_1D_002.png + :scale: 47 + + * + + .. image:: /auto_examples/images/sphx_glr_plot_OT_1D_003.png + :scale: 47 + + * + + .. image:: /auto_examples/images/sphx_glr_plot_OT_1D_004.png + :scale: 47 + + + + + +.. code-block:: python + + + import numpy as np + import matplotlib.pylab as pl + import ot + from ot.datasets import get_1D_gauss as gauss + + + #%% parameters + + n=100 # nb bins + + # bin positions + x=np.arange(n,dtype=np.float64) + + # Gaussian distributions + a=gauss(n,m=20,s=5) # m= mean, s= std + b=gauss(n,m=60,s=10) + + # loss matrix + M=ot.dist(x.reshape((n,1)),x.reshape((n,1))) + M/=M.max() + + #%% plot the distributions + + pl.figure(1) + pl.plot(x,a,'b',label='Source distribution') + pl.plot(x,b,'r',label='Target distribution') + pl.legend() + + #%% plot distributions and loss matrix + + pl.figure(2) + ot.plot.plot1D_mat(a,b,M,'Cost matrix M') + + #%% EMD + + G0=ot.emd(a,b,M) + + pl.figure(3) + ot.plot.plot1D_mat(a,b,G0,'OT matrix G0') + + #%% Sinkhorn + + lambd=1e-3 + Gs=ot.sinkhorn(a,b,M,lambd) + + pl.figure(4) + ot.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn') + +**Total running time of the script:** ( 0 minutes 0.597 seconds) + + + +.. container:: sphx-glr-footer + + + .. container:: sphx-glr-download + + :download:`Download Python source code: plot_OT_1D.py <plot_OT_1D.py>` + + + + .. container:: sphx-glr-download + + :download:`Download Jupyter notebook: plot_OT_1D.ipynb <plot_OT_1D.ipynb>` + +.. rst-class:: sphx-glr-signature + + `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_ diff --git a/docs/source/auto_examples/plot_OT_2D_samples.ipynb b/docs/source/auto_examples/plot_OT_2D_samples.ipynb new file mode 100644 index 0000000..e8ec1d1 --- /dev/null +++ b/docs/source/auto_examples/plot_OT_2D_samples.ipynb @@ -0,0 +1,54 @@ +{ + "nbformat_minor": 0, + "nbformat": 4, + "cells": [ + { + "execution_count": null, + "cell_type": "code", + "source": [ + "%matplotlib inline" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + }, + { + "source": [ + "\n# 2D Optimal transport between empirical distributions\n\n\n@author: rflamary\n\n" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "execution_count": null, + "cell_type": "code", + "source": [ + "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\n\n#%% parameters and data generation\n\nn=20 # nb samples\n\nmu_s=np.array([0,0])\ncov_s=np.array([[1,0],[0,1]])\n\nmu_t=np.array([4,4])\ncov_t=np.array([[1,-.8],[-.8,1]])\n\nxs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s)\nxt=ot.datasets.get_2D_samples_gauss(n,mu_t,cov_t)\n\na,b = ot.unif(n),ot.unif(n) # uniform distribution on samples\n\n# loss matrix\nM=ot.dist(xs,xt)\nM/=M.max()\n\n#%% plot samples\n\npl.figure(1)\npl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\npl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\npl.legend(loc=0)\npl.title('Source and traget distributions')\n\npl.figure(2)\npl.imshow(M,interpolation='nearest')\npl.title('Cost matrix M')\n\n\n#%% EMD\n\nG0=ot.emd(a,b,M)\n\npl.figure(3)\npl.imshow(G0,interpolation='nearest')\npl.title('OT matrix G0')\n\npl.figure(4)\not.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])\npl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\npl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\npl.legend(loc=0)\npl.title('OT matrix with samples')\n\n\n#%% sinkhorn\n\n# reg term\nlambd=5e-3\n\nGs=ot.sinkhorn(a,b,M,lambd)\n\npl.figure(5)\npl.imshow(Gs,interpolation='nearest')\npl.title('OT matrix sinkhorn')\n\npl.figure(6)\not.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1])\npl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\npl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\npl.legend(loc=0)\npl.title('OT matrix Sinkhorn with samples')" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "name": "python2", + "language": "python" + }, + "language_info": { + "mimetype": "text/x-python", + "nbconvert_exporter": "python", + "name": "python", + "file_extension": ".py", + "version": "2.7.12", + "pygments_lexer": "ipython2", + "codemirror_mode": { + "version": 2, + "name": "ipython" + } + } + } +}
\ No newline at end of file diff --git a/docs/source/auto_examples/plot_OT_2D_samples.py b/docs/source/auto_examples/plot_OT_2D_samples.py new file mode 100644 index 0000000..6c39ad4 --- /dev/null +++ b/docs/source/auto_examples/plot_OT_2D_samples.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- +""" +==================================================== +2D Optimal transport between empirical distributions +==================================================== + +@author: rflamary +""" + +import numpy as np +import matplotlib.pylab as pl +import ot + +#%% parameters and data generation + +n=20 # nb samples + +mu_s=np.array([0,0]) +cov_s=np.array([[1,0],[0,1]]) + +mu_t=np.array([4,4]) +cov_t=np.array([[1,-.8],[-.8,1]]) + +xs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s) +xt=ot.datasets.get_2D_samples_gauss(n,mu_t,cov_t) + +a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples + +# loss matrix +M=ot.dist(xs,xt) +M/=M.max() + +#%% plot samples + +pl.figure(1) +pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') +pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') +pl.legend(loc=0) +pl.title('Source and traget distributions') + +pl.figure(2) +pl.imshow(M,interpolation='nearest') +pl.title('Cost matrix M') + + +#%% EMD + +G0=ot.emd(a,b,M) + +pl.figure(3) +pl.imshow(G0,interpolation='nearest') +pl.title('OT matrix G0') + +pl.figure(4) +ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1]) +pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') +pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') +pl.legend(loc=0) +pl.title('OT matrix with samples') + + +#%% sinkhorn + +# reg term +lambd=5e-3 + +Gs=ot.sinkhorn(a,b,M,lambd) + +pl.figure(5) +pl.imshow(Gs,interpolation='nearest') +pl.title('OT matrix sinkhorn') + +pl.figure(6) +ot.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1]) +pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') +pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') +pl.legend(loc=0) +pl.title('OT matrix Sinkhorn with samples') diff --git a/docs/source/auto_examples/plot_OT_2D_samples.rst b/docs/source/auto_examples/plot_OT_2D_samples.rst new file mode 100644 index 0000000..bc86cb8 --- /dev/null +++ b/docs/source/auto_examples/plot_OT_2D_samples.rst @@ -0,0 +1,144 @@ + + +.. _sphx_glr_auto_examples_plot_OT_2D_samples.py: + + +==================================================== +2D Optimal transport between empirical distributions +==================================================== + +@author: rflamary + + + + +.. rst-class:: sphx-glr-horizontal + + + * + + .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png + :scale: 47 + + * + + .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png + :scale: 47 + + * + + .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_003.png + :scale: 47 + + * + + .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_004.png + :scale: 47 + + * + + .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png + :scale: 47 + + * + + .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png + :scale: 47 + + + + + +.. code-block:: python + + + import numpy as np + import matplotlib.pylab as pl + import ot + + #%% parameters and data generation + + n=20 # nb samples + + mu_s=np.array([0,0]) + cov_s=np.array([[1,0],[0,1]]) + + mu_t=np.array([4,4]) + cov_t=np.array([[1,-.8],[-.8,1]]) + + xs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s) + xt=ot.datasets.get_2D_samples_gauss(n,mu_t,cov_t) + + a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples + + # loss matrix + M=ot.dist(xs,xt) + M/=M.max() + + #%% plot samples + + pl.figure(1) + pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') + pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') + pl.legend(loc=0) + pl.title('Source and traget distributions') + + pl.figure(2) + pl.imshow(M,interpolation='nearest') + pl.title('Cost matrix M') + + + #%% EMD + + G0=ot.emd(a,b,M) + + pl.figure(3) + pl.imshow(G0,interpolation='nearest') + pl.title('OT matrix G0') + + pl.figure(4) + ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1]) + pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') + pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') + pl.legend(loc=0) + pl.title('OT matrix with samples') + + + #%% sinkhorn + + # reg term + lambd=5e-3 + + Gs=ot.sinkhorn(a,b,M,lambd) + + pl.figure(5) + pl.imshow(Gs,interpolation='nearest') + pl.title('OT matrix sinkhorn') + + pl.figure(6) + ot.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1]) + pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') + pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') + pl.legend(loc=0) + pl.title('OT matrix Sinkhorn with samples') + +**Total running time of the script:** ( 0 minutes 1.051 seconds) + + + +.. container:: sphx-glr-footer + + + .. container:: sphx-glr-download + + :download:`Download Python source code: plot_OT_2D_samples.py <plot_OT_2D_samples.py>` + + + + .. container:: sphx-glr-download + + :download:`Download Jupyter notebook: plot_OT_2D_samples.ipynb <plot_OT_2D_samples.ipynb>` + +.. rst-class:: sphx-glr-signature + + `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_ diff --git a/docs/source/auto_examples/plot_barycenter_1D.ipynb b/docs/source/auto_examples/plot_barycenter_1D.ipynb new file mode 100644 index 0000000..36f3975 --- /dev/null +++ b/docs/source/auto_examples/plot_barycenter_1D.ipynb @@ -0,0 +1,54 @@ +{ + "nbformat_minor": 0, + "nbformat": 4, + "cells": [ + { + "execution_count": null, + "cell_type": "code", + "source": [ + "%matplotlib inline" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + }, + { + "source": [ + "\n# 1D Wasserstein barycenter demo\n\n\n\n@author: rflamary\n\n" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "execution_count": null, + "cell_type": "code", + "source": [ + "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\nfrom mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used\nfrom matplotlib.collections import PolyCollection\n\n\n#%% parameters\n\nn=100 # nb bins\n\n# bin positions\nx=np.arange(n,dtype=np.float64)\n\n# Gaussian distributions\na1=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std\na2=ot.datasets.get_1D_gauss(n,m=60,s=8)\n\n# creating matrix A containing all distributions\nA=np.vstack((a1,a2)).T\nnbd=A.shape[1]\n\n# loss matrix + normalization\nM=ot.utils.dist0(n)\nM/=M.max()\n\n#%% plot the distributions\n\npl.figure(1)\nfor i in range(nbd):\n pl.plot(x,A[:,i])\npl.title('Distributions')\n\n#%% barycenter computation\n\nalpha=0.2 # 0<=alpha<=1\nweights=np.array([1-alpha,alpha])\n\n# l2bary\nbary_l2=A.dot(weights)\n\n# wasserstein\nreg=1e-3\nbary_wass=ot.bregman.barycenter(A,M,reg,weights)\n\npl.figure(2)\npl.clf()\npl.subplot(2,1,1)\nfor i in range(nbd):\n pl.plot(x,A[:,i])\npl.title('Distributions')\n\npl.subplot(2,1,2)\npl.plot(x,bary_l2,'r',label='l2')\npl.plot(x,bary_wass,'g',label='Wasserstein')\npl.legend()\npl.title('Barycenters')\n\n\n#%% barycenter interpolation\n\nnbalpha=11\nalphalist=np.linspace(0,1,nbalpha)\n\n\nB_l2=np.zeros((n,nbalpha))\n\nB_wass=np.copy(B_l2)\n\nfor i in range(0,nbalpha):\n alpha=alphalist[i]\n weights=np.array([1-alpha,alpha])\n B_l2[:,i]=A.dot(weights)\n B_wass[:,i]=ot.bregman.barycenter(A,M,reg,weights)\n\n#%% plot interpolation\n\npl.figure(3,(10,5))\n\n#pl.subplot(1,2,1)\ncmap=pl.cm.get_cmap('viridis')\nverts = []\nzs = alphalist\nfor i,z in enumerate(zs):\n ys = B_l2[:,i]\n verts.append(list(zip(x, ys)))\n\nax = pl.gcf().gca(projection='3d')\n\npoly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist])\npoly.set_alpha(0.7)\nax.add_collection3d(poly, zs=zs, zdir='y')\n\nax.set_xlabel('x')\nax.set_xlim3d(0, n)\nax.set_ylabel('$\\\\alpha$')\nax.set_ylim3d(0,1)\nax.set_zlabel('')\nax.set_zlim3d(0, B_l2.max()*1.01)\npl.title('Barycenter interpolation with l2')\n\npl.show()\n\npl.figure(4,(10,5))\n\n#pl.subplot(1,2,1)\ncmap=pl.cm.get_cmap('viridis')\nverts = []\nzs = alphalist\nfor i,z in enumerate(zs):\n ys = B_wass[:,i]\n verts.append(list(zip(x, ys)))\n\nax = pl.gcf().gca(projection='3d')\n\npoly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist])\npoly.set_alpha(0.7)\nax.add_collection3d(poly, zs=zs, zdir='y')\n\nax.set_xlabel('x')\nax.set_xlim3d(0, n)\nax.set_ylabel('$\\\\alpha$')\nax.set_ylim3d(0,1)\nax.set_zlabel('')\nax.set_zlim3d(0, B_l2.max()*1.01)\npl.title('Barycenter interpolation with Wasserstein')\n\npl.show()" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "name": "python2", + "language": "python" + }, + "language_info": { + "mimetype": "text/x-python", + "nbconvert_exporter": "python", + "name": "python", + "file_extension": ".py", + "version": "2.7.12", + "pygments_lexer": "ipython2", + "codemirror_mode": { + "version": 2, + "name": "ipython" + } + } + } +}
\ No newline at end of file diff --git a/docs/source/auto_examples/plot_barycenter_1D.py b/docs/source/auto_examples/plot_barycenter_1D.py new file mode 100644 index 0000000..30eecbf --- /dev/null +++ b/docs/source/auto_examples/plot_barycenter_1D.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +""" +============================== +1D Wasserstein barycenter demo +============================== + + +@author: rflamary +""" + +import numpy as np +import matplotlib.pylab as pl +import ot +from mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used +from matplotlib.collections import PolyCollection + + +#%% parameters + +n=100 # nb bins + +# bin positions +x=np.arange(n,dtype=np.float64) + +# Gaussian distributions +a1=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std +a2=ot.datasets.get_1D_gauss(n,m=60,s=8) + +# creating matrix A containing all distributions +A=np.vstack((a1,a2)).T +nbd=A.shape[1] + +# loss matrix + normalization +M=ot.utils.dist0(n) +M/=M.max() + +#%% plot the distributions + +pl.figure(1) +for i in range(nbd): + pl.plot(x,A[:,i]) +pl.title('Distributions') + +#%% barycenter computation + +alpha=0.2 # 0<=alpha<=1 +weights=np.array([1-alpha,alpha]) + +# l2bary +bary_l2=A.dot(weights) + +# wasserstein +reg=1e-3 +bary_wass=ot.bregman.barycenter(A,M,reg,weights) + +pl.figure(2) +pl.clf() +pl.subplot(2,1,1) +for i in range(nbd): + pl.plot(x,A[:,i]) +pl.title('Distributions') + +pl.subplot(2,1,2) +pl.plot(x,bary_l2,'r',label='l2') +pl.plot(x,bary_wass,'g',label='Wasserstein') +pl.legend() +pl.title('Barycenters') + + +#%% barycenter interpolation + +nbalpha=11 +alphalist=np.linspace(0,1,nbalpha) + + +B_l2=np.zeros((n,nbalpha)) + +B_wass=np.copy(B_l2) + +for i in range(0,nbalpha): + alpha=alphalist[i] + weights=np.array([1-alpha,alpha]) + B_l2[:,i]=A.dot(weights) + B_wass[:,i]=ot.bregman.barycenter(A,M,reg,weights) + +#%% plot interpolation + +pl.figure(3,(10,5)) + +#pl.subplot(1,2,1) +cmap=pl.cm.get_cmap('viridis') +verts = [] +zs = alphalist +for i,z in enumerate(zs): + ys = B_l2[:,i] + verts.append(list(zip(x, ys))) + +ax = pl.gcf().gca(projection='3d') + +poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist]) +poly.set_alpha(0.7) +ax.add_collection3d(poly, zs=zs, zdir='y') + +ax.set_xlabel('x') +ax.set_xlim3d(0, n) +ax.set_ylabel('$\\alpha$') +ax.set_ylim3d(0,1) +ax.set_zlabel('') +ax.set_zlim3d(0, B_l2.max()*1.01) +pl.title('Barycenter interpolation with l2') + +pl.show() + +pl.figure(4,(10,5)) + +#pl.subplot(1,2,1) +cmap=pl.cm.get_cmap('viridis') +verts = [] +zs = alphalist +for i,z in enumerate(zs): + ys = B_wass[:,i] + verts.append(list(zip(x, ys))) + +ax = pl.gcf().gca(projection='3d') + +poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist]) +poly.set_alpha(0.7) +ax.add_collection3d(poly, zs=zs, zdir='y') + +ax.set_xlabel('x') +ax.set_xlim3d(0, n) +ax.set_ylabel('$\\alpha$') +ax.set_ylim3d(0,1) +ax.set_zlabel('') +ax.set_zlim3d(0, B_l2.max()*1.01) +pl.title('Barycenter interpolation with Wasserstein') + +pl.show()
\ No newline at end of file diff --git a/docs/source/auto_examples/plot_barycenter_1D.rst b/docs/source/auto_examples/plot_barycenter_1D.rst new file mode 100644 index 0000000..1b15c77 --- /dev/null +++ b/docs/source/auto_examples/plot_barycenter_1D.rst @@ -0,0 +1,193 @@ + + +.. _sphx_glr_auto_examples_plot_barycenter_1D.py: + + +============================== +1D Wasserstein barycenter demo +============================== + + +@author: rflamary + + + + +.. rst-class:: sphx-glr-horizontal + + + * + + .. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_001.png + :scale: 47 + + * + + .. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_002.png + :scale: 47 + + * + + .. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_003.png + :scale: 47 + + * + + .. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_004.png + :scale: 47 + + + + + +.. code-block:: python + + + import numpy as np + import matplotlib.pylab as pl + import ot + from mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used + from matplotlib.collections import PolyCollection + + + #%% parameters + + n=100 # nb bins + + # bin positions + x=np.arange(n,dtype=np.float64) + + # Gaussian distributions + a1=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std + a2=ot.datasets.get_1D_gauss(n,m=60,s=8) + + # creating matrix A containing all distributions + A=np.vstack((a1,a2)).T + nbd=A.shape[1] + + # loss matrix + normalization + M=ot.utils.dist0(n) + M/=M.max() + + #%% plot the distributions + + pl.figure(1) + for i in range(nbd): + pl.plot(x,A[:,i]) + pl.title('Distributions') + + #%% barycenter computation + + alpha=0.2 # 0<=alpha<=1 + weights=np.array([1-alpha,alpha]) + + # l2bary + bary_l2=A.dot(weights) + + # wasserstein + reg=1e-3 + bary_wass=ot.bregman.barycenter(A,M,reg,weights) + + pl.figure(2) + pl.clf() + pl.subplot(2,1,1) + for i in range(nbd): + pl.plot(x,A[:,i]) + pl.title('Distributions') + + pl.subplot(2,1,2) + pl.plot(x,bary_l2,'r',label='l2') + pl.plot(x,bary_wass,'g',label='Wasserstein') + pl.legend() + pl.title('Barycenters') + + + #%% barycenter interpolation + + nbalpha=11 + alphalist=np.linspace(0,1,nbalpha) + + + B_l2=np.zeros((n,nbalpha)) + + B_wass=np.copy(B_l2) + + for i in range(0,nbalpha): + alpha=alphalist[i] + weights=np.array([1-alpha,alpha]) + B_l2[:,i]=A.dot(weights) + B_wass[:,i]=ot.bregman.barycenter(A,M,reg,weights) + + #%% plot interpolation + + pl.figure(3,(10,5)) + + #pl.subplot(1,2,1) + cmap=pl.cm.get_cmap('viridis') + verts = [] + zs = alphalist + for i,z in enumerate(zs): + ys = B_l2[:,i] + verts.append(list(zip(x, ys))) + + ax = pl.gcf().gca(projection='3d') + + poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist]) + poly.set_alpha(0.7) + ax.add_collection3d(poly, zs=zs, zdir='y') + + ax.set_xlabel('x') + ax.set_xlim3d(0, n) + ax.set_ylabel('$\\alpha$') + ax.set_ylim3d(0,1) + ax.set_zlabel('') + ax.set_zlim3d(0, B_l2.max()*1.01) + pl.title('Barycenter interpolation with l2') + + pl.show() + + pl.figure(4,(10,5)) + + #pl.subplot(1,2,1) + cmap=pl.cm.get_cmap('viridis') + verts = [] + zs = alphalist + for i,z in enumerate(zs): + ys = B_wass[:,i] + verts.append(list(zip(x, ys))) + + ax = pl.gcf().gca(projection='3d') + + poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist]) + poly.set_alpha(0.7) + ax.add_collection3d(poly, zs=zs, zdir='y') + + ax.set_xlabel('x') + ax.set_xlim3d(0, n) + ax.set_ylabel('$\\alpha$') + ax.set_ylim3d(0,1) + ax.set_zlabel('') + ax.set_zlim3d(0, B_l2.max()*1.01) + pl.title('Barycenter interpolation with Wasserstein') + + pl.show() +**Total running time of the script:** ( 0 minutes 2.274 seconds) + + + +.. container:: sphx-glr-footer + + + .. container:: sphx-glr-download + + :download:`Download Python source code: plot_barycenter_1D.py <plot_barycenter_1D.py>` + + + + .. container:: sphx-glr-download + + :download:`Download Jupyter notebook: plot_barycenter_1D.ipynb <plot_barycenter_1D.ipynb>` + +.. rst-class:: sphx-glr-signature + + `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_ diff --git a/docs/source/auto_examples/plot_optim_OTreg.ipynb b/docs/source/auto_examples/plot_optim_OTreg.ipynb new file mode 100644 index 0000000..250ea72 --- /dev/null +++ b/docs/source/auto_examples/plot_optim_OTreg.ipynb @@ -0,0 +1,54 @@ +{ + "nbformat_minor": 0, + "nbformat": 4, + "cells": [ + { + "execution_count": null, + "cell_type": "code", + "source": [ + "%matplotlib inline" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + }, + { + "source": [ + "\n# Regularized OT with generic solver\n\n\n\n\n" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "execution_count": null, + "cell_type": "code", + "source": [ + "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\n\n\n\n#%% parameters\n\nn=100 # nb bins\n\n# bin positions\nx=np.arange(n,dtype=np.float64)\n\n# Gaussian distributions\na=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std\nb=ot.datasets.get_1D_gauss(n,m=60,s=10)\n\n# loss matrix\nM=ot.dist(x.reshape((n,1)),x.reshape((n,1)))\nM/=M.max()\n\n#%% EMD\n\nG0=ot.emd(a,b,M)\n\npl.figure(3)\not.plot.plot1D_mat(a,b,G0,'OT matrix G0')\n\n#%% Example with Frobenius norm regularization\n\ndef f(G): return 0.5*np.sum(G**2)\ndef df(G): return G\n\nreg=1e-1\n\nGl2=ot.optim.cg(a,b,M,reg,f,df,verbose=True)\n\npl.figure(3)\not.plot.plot1D_mat(a,b,Gl2,'OT matrix Frob. reg')\n\n#%% Example with entropic regularization\n\ndef f(G): return np.sum(G*np.log(G))\ndef df(G): return np.log(G)+1\n\nreg=1e-3\n\nGe=ot.optim.cg(a,b,M,reg,f,df,verbose=True)\n\npl.figure(4)\not.plot.plot1D_mat(a,b,Ge,'OT matrix Entrop. reg')\n\n#%% Example with Frobenius norm + entropic regularization with gcg\n\ndef f(G): return 0.5*np.sum(G**2)\ndef df(G): return G\n\nreg1=1e-1\nreg2=1e-1\n\nGel2=ot.optim.gcg(a,b,M,reg1,reg2,f,df,verbose=True)\n\npl.figure(5)\not.plot.plot1D_mat(a,b,Gel2,'OT entropic + matrix Frob. reg')" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "name": "python2", + "language": "python" + }, + "language_info": { + "mimetype": "text/x-python", + "nbconvert_exporter": "python", + "name": "python", + "file_extension": ".py", + "version": "2.7.12", + "pygments_lexer": "ipython2", + "codemirror_mode": { + "version": 2, + "name": "ipython" + } + } + } +}
\ No newline at end of file diff --git a/docs/source/auto_examples/plot_optim_OTreg.py b/docs/source/auto_examples/plot_optim_OTreg.py new file mode 100644 index 0000000..3c4d3f4 --- /dev/null +++ b/docs/source/auto_examples/plot_optim_OTreg.py @@ -0,0 +1,73 @@ +# -*- coding: utf-8 -*- +""" +================================== +Regularized OT with generic solver +================================== + + +""" + +import numpy as np +import matplotlib.pylab as pl +import ot + + + +#%% parameters + +n=100 # nb bins + +# bin positions +x=np.arange(n,dtype=np.float64) + +# Gaussian distributions +a=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std +b=ot.datasets.get_1D_gauss(n,m=60,s=10) + +# loss matrix +M=ot.dist(x.reshape((n,1)),x.reshape((n,1))) +M/=M.max() + +#%% EMD + +G0=ot.emd(a,b,M) + +pl.figure(3) +ot.plot.plot1D_mat(a,b,G0,'OT matrix G0') + +#%% Example with Frobenius norm regularization + +def f(G): return 0.5*np.sum(G**2) +def df(G): return G + +reg=1e-1 + +Gl2=ot.optim.cg(a,b,M,reg,f,df,verbose=True) + +pl.figure(3) +ot.plot.plot1D_mat(a,b,Gl2,'OT matrix Frob. reg') + +#%% Example with entropic regularization + +def f(G): return np.sum(G*np.log(G)) +def df(G): return np.log(G)+1 + +reg=1e-3 + +Ge=ot.optim.cg(a,b,M,reg,f,df,verbose=True) + +pl.figure(4) +ot.plot.plot1D_mat(a,b,Ge,'OT matrix Entrop. reg') + +#%% Example with Frobenius norm + entropic regularization with gcg + +def f(G): return 0.5*np.sum(G**2) +def df(G): return G + +reg1=1e-1 +reg2=1e-1 + +Gel2=ot.optim.gcg(a,b,M,reg1,reg2,f,df,verbose=True) + +pl.figure(5) +ot.plot.plot1D_mat(a,b,Gel2,'OT entropic + matrix Frob. reg')
\ No newline at end of file diff --git a/docs/source/auto_examples/plot_optim_OTreg.rst b/docs/source/auto_examples/plot_optim_OTreg.rst new file mode 100644 index 0000000..d6397ba --- /dev/null +++ b/docs/source/auto_examples/plot_optim_OTreg.rst @@ -0,0 +1,583 @@ + + +.. _sphx_glr_auto_examples_plot_optim_OTreg.py: + + +================================== +Regularized OT with generic solver +================================== + + + + + + +.. rst-class:: sphx-glr-horizontal + + + * + + .. image:: /auto_examples/images/sphx_glr_plot_optim_OTreg_003.png + :scale: 47 + + * + + .. image:: /auto_examples/images/sphx_glr_plot_optim_OTreg_004.png + :scale: 47 + + * + + .. image:: /auto_examples/images/sphx_glr_plot_optim_OTreg_005.png + :scale: 47 + + +.. rst-class:: sphx-glr-script-out + + Out:: + + It. |Loss |Delta loss + -------------------------------- + 0|1.760578e-01|0.000000e+00 + 1|1.669467e-01|-5.457501e-02 + 2|1.665639e-01|-2.298130e-03 + 3|1.664378e-01|-7.572776e-04 + 4|1.664077e-01|-1.811855e-04 + 5|1.663912e-01|-9.936787e-05 + 6|1.663852e-01|-3.555826e-05 + 7|1.663814e-01|-2.305693e-05 + 8|1.663785e-01|-1.760450e-05 + 9|1.663767e-01|-1.078011e-05 + 10|1.663751e-01|-9.525192e-06 + 11|1.663737e-01|-8.396466e-06 + 12|1.663727e-01|-6.086938e-06 + 13|1.663720e-01|-4.042609e-06 + 14|1.663713e-01|-4.160914e-06 + 15|1.663707e-01|-3.823502e-06 + 16|1.663702e-01|-3.022440e-06 + 17|1.663697e-01|-3.181249e-06 + 18|1.663692e-01|-2.698532e-06 + 19|1.663687e-01|-3.258253e-06 + It. |Loss |Delta loss + -------------------------------- + 20|1.663682e-01|-2.741118e-06 + 21|1.663678e-01|-2.624135e-06 + 22|1.663673e-01|-2.645179e-06 + 23|1.663670e-01|-1.957237e-06 + 24|1.663666e-01|-2.261541e-06 + 25|1.663663e-01|-1.851305e-06 + 26|1.663660e-01|-1.942296e-06 + 27|1.663657e-01|-2.092896e-06 + 28|1.663653e-01|-1.924361e-06 + 29|1.663651e-01|-1.625455e-06 + 30|1.663648e-01|-1.641123e-06 + 31|1.663645e-01|-1.566666e-06 + 32|1.663643e-01|-1.338514e-06 + 33|1.663641e-01|-1.222711e-06 + 34|1.663639e-01|-1.221805e-06 + 35|1.663637e-01|-1.440781e-06 + 36|1.663634e-01|-1.520091e-06 + 37|1.663632e-01|-1.288193e-06 + 38|1.663630e-01|-1.123055e-06 + 39|1.663628e-01|-1.024487e-06 + It. |Loss |Delta loss + -------------------------------- + 40|1.663627e-01|-1.079606e-06 + 41|1.663625e-01|-1.172093e-06 + 42|1.663623e-01|-1.047880e-06 + 43|1.663621e-01|-1.010577e-06 + 44|1.663619e-01|-1.064438e-06 + 45|1.663618e-01|-9.882375e-07 + 46|1.663616e-01|-8.532647e-07 + 47|1.663615e-01|-9.930189e-07 + 48|1.663613e-01|-8.728955e-07 + 49|1.663612e-01|-9.524214e-07 + 50|1.663610e-01|-9.088418e-07 + 51|1.663609e-01|-7.639430e-07 + 52|1.663608e-01|-6.662611e-07 + 53|1.663607e-01|-7.133700e-07 + 54|1.663605e-01|-7.648141e-07 + 55|1.663604e-01|-6.557516e-07 + 56|1.663603e-01|-7.304213e-07 + 57|1.663602e-01|-6.353809e-07 + 58|1.663601e-01|-7.968279e-07 + 59|1.663600e-01|-6.367159e-07 + It. |Loss |Delta loss + -------------------------------- + 60|1.663599e-01|-5.610790e-07 + 61|1.663598e-01|-5.787466e-07 + 62|1.663596e-01|-6.937777e-07 + 63|1.663596e-01|-5.599432e-07 + 64|1.663595e-01|-5.813048e-07 + 65|1.663594e-01|-5.724600e-07 + 66|1.663593e-01|-6.081892e-07 + 67|1.663592e-01|-5.948732e-07 + 68|1.663591e-01|-4.941833e-07 + 69|1.663590e-01|-5.213739e-07 + 70|1.663589e-01|-5.127355e-07 + 71|1.663588e-01|-4.349251e-07 + 72|1.663588e-01|-5.007084e-07 + 73|1.663587e-01|-4.880265e-07 + 74|1.663586e-01|-4.931950e-07 + 75|1.663585e-01|-4.981309e-07 + 76|1.663584e-01|-3.952959e-07 + 77|1.663584e-01|-4.544857e-07 + 78|1.663583e-01|-4.237579e-07 + 79|1.663582e-01|-4.382386e-07 + It. |Loss |Delta loss + -------------------------------- + 80|1.663582e-01|-3.646051e-07 + 81|1.663581e-01|-4.197994e-07 + 82|1.663580e-01|-4.072764e-07 + 83|1.663580e-01|-3.994645e-07 + 84|1.663579e-01|-4.842721e-07 + 85|1.663578e-01|-3.276486e-07 + 86|1.663578e-01|-3.737346e-07 + 87|1.663577e-01|-4.282043e-07 + 88|1.663576e-01|-4.020937e-07 + 89|1.663576e-01|-3.431951e-07 + 90|1.663575e-01|-3.052335e-07 + 91|1.663575e-01|-3.500538e-07 + 92|1.663574e-01|-3.063176e-07 + 93|1.663573e-01|-3.576367e-07 + 94|1.663573e-01|-3.224681e-07 + 95|1.663572e-01|-3.673221e-07 + 96|1.663572e-01|-3.635561e-07 + 97|1.663571e-01|-3.527236e-07 + 98|1.663571e-01|-2.788548e-07 + 99|1.663570e-01|-2.727141e-07 + It. |Loss |Delta loss + -------------------------------- + 100|1.663570e-01|-3.127278e-07 + 101|1.663569e-01|-2.637504e-07 + 102|1.663569e-01|-2.922750e-07 + 103|1.663568e-01|-3.076454e-07 + 104|1.663568e-01|-2.911509e-07 + 105|1.663567e-01|-2.403398e-07 + 106|1.663567e-01|-2.439790e-07 + 107|1.663567e-01|-2.634542e-07 + 108|1.663566e-01|-2.452203e-07 + 109|1.663566e-01|-2.852991e-07 + 110|1.663565e-01|-2.165490e-07 + 111|1.663565e-01|-2.450250e-07 + 112|1.663564e-01|-2.685294e-07 + 113|1.663564e-01|-2.821800e-07 + 114|1.663564e-01|-2.237390e-07 + 115|1.663563e-01|-1.992842e-07 + 116|1.663563e-01|-2.166739e-07 + 117|1.663563e-01|-2.086064e-07 + 118|1.663562e-01|-2.435945e-07 + 119|1.663562e-01|-2.292497e-07 + It. |Loss |Delta loss + -------------------------------- + 120|1.663561e-01|-2.366209e-07 + 121|1.663561e-01|-2.138746e-07 + 122|1.663561e-01|-2.009637e-07 + 123|1.663560e-01|-2.386258e-07 + 124|1.663560e-01|-1.927442e-07 + 125|1.663560e-01|-2.081681e-07 + 126|1.663559e-01|-1.759123e-07 + 127|1.663559e-01|-1.890771e-07 + 128|1.663559e-01|-1.971315e-07 + 129|1.663558e-01|-2.101983e-07 + 130|1.663558e-01|-2.035645e-07 + 131|1.663558e-01|-1.984492e-07 + 132|1.663557e-01|-1.849064e-07 + 133|1.663557e-01|-1.795703e-07 + 134|1.663557e-01|-1.624087e-07 + 135|1.663557e-01|-1.689557e-07 + 136|1.663556e-01|-1.644308e-07 + 137|1.663556e-01|-1.618007e-07 + 138|1.663556e-01|-1.483013e-07 + 139|1.663555e-01|-1.708771e-07 + It. |Loss |Delta loss + -------------------------------- + 140|1.663555e-01|-2.013847e-07 + 141|1.663555e-01|-1.721217e-07 + 142|1.663554e-01|-2.027911e-07 + 143|1.663554e-01|-1.764565e-07 + 144|1.663554e-01|-1.677151e-07 + 145|1.663554e-01|-1.351982e-07 + 146|1.663553e-01|-1.423360e-07 + 147|1.663553e-01|-1.541112e-07 + 148|1.663553e-01|-1.491601e-07 + 149|1.663553e-01|-1.466407e-07 + 150|1.663552e-01|-1.801524e-07 + 151|1.663552e-01|-1.714107e-07 + 152|1.663552e-01|-1.491257e-07 + 153|1.663552e-01|-1.513799e-07 + 154|1.663551e-01|-1.354539e-07 + 155|1.663551e-01|-1.233818e-07 + 156|1.663551e-01|-1.576219e-07 + 157|1.663551e-01|-1.452791e-07 + 158|1.663550e-01|-1.262867e-07 + 159|1.663550e-01|-1.316379e-07 + It. |Loss |Delta loss + -------------------------------- + 160|1.663550e-01|-1.295447e-07 + 161|1.663550e-01|-1.283286e-07 + 162|1.663550e-01|-1.569222e-07 + 163|1.663549e-01|-1.172942e-07 + 164|1.663549e-01|-1.399809e-07 + 165|1.663549e-01|-1.229432e-07 + 166|1.663549e-01|-1.326191e-07 + 167|1.663548e-01|-1.209694e-07 + 168|1.663548e-01|-1.372136e-07 + 169|1.663548e-01|-1.338395e-07 + 170|1.663548e-01|-1.416497e-07 + 171|1.663548e-01|-1.298576e-07 + 172|1.663547e-01|-1.190590e-07 + 173|1.663547e-01|-1.167083e-07 + 174|1.663547e-01|-1.069425e-07 + 175|1.663547e-01|-1.217780e-07 + 176|1.663547e-01|-1.140754e-07 + 177|1.663546e-01|-1.160707e-07 + 178|1.663546e-01|-1.101798e-07 + 179|1.663546e-01|-1.114904e-07 + It. |Loss |Delta loss + -------------------------------- + 180|1.663546e-01|-1.064022e-07 + 181|1.663546e-01|-9.258231e-08 + 182|1.663546e-01|-1.213120e-07 + 183|1.663545e-01|-1.164296e-07 + 184|1.663545e-01|-1.188762e-07 + 185|1.663545e-01|-9.394153e-08 + 186|1.663545e-01|-1.028656e-07 + 187|1.663545e-01|-1.115348e-07 + 188|1.663544e-01|-9.768310e-08 + 189|1.663544e-01|-1.021806e-07 + 190|1.663544e-01|-1.086303e-07 + 191|1.663544e-01|-9.879008e-08 + 192|1.663544e-01|-1.050210e-07 + 193|1.663544e-01|-1.002463e-07 + 194|1.663543e-01|-1.062747e-07 + 195|1.663543e-01|-9.348538e-08 + 196|1.663543e-01|-7.992512e-08 + 197|1.663543e-01|-9.558020e-08 + 198|1.663543e-01|-9.993772e-08 + 199|1.663543e-01|-8.588499e-08 + It. |Loss |Delta loss + -------------------------------- + 200|1.663543e-01|-8.737134e-08 + It. |Loss |Delta loss + -------------------------------- + 0|1.692289e-01|0.000000e+00 + 1|1.617643e-01|-4.614437e-02 + 2|1.612546e-01|-3.161037e-03 + 3|1.611040e-01|-9.349544e-04 + 4|1.610346e-01|-4.310179e-04 + 5|1.610072e-01|-1.701719e-04 + 6|1.609947e-01|-7.759814e-05 + 7|1.609934e-01|-7.941439e-06 + 8|1.609841e-01|-5.797180e-05 + 9|1.609838e-01|-1.559407e-06 + 10|1.609685e-01|-9.530282e-05 + 11|1.609666e-01|-1.142129e-05 + 12|1.609541e-01|-7.799970e-05 + 13|1.609496e-01|-2.780416e-05 + 14|1.609385e-01|-6.887105e-05 + 15|1.609334e-01|-3.174241e-05 + 16|1.609231e-01|-6.420777e-05 + 17|1.609115e-01|-7.189949e-05 + 18|1.608815e-01|-1.865331e-04 + 19|1.608799e-01|-1.013039e-05 + It. |Loss |Delta loss + -------------------------------- + 20|1.608695e-01|-6.468606e-05 + 21|1.608686e-01|-5.738419e-06 + 22|1.608661e-01|-1.495923e-05 + 23|1.608657e-01|-2.784611e-06 + 24|1.608633e-01|-1.512408e-05 + 25|1.608624e-01|-5.397916e-06 + 26|1.608617e-01|-4.115218e-06 + 27|1.608561e-01|-3.503396e-05 + 28|1.608479e-01|-5.098773e-05 + 29|1.608452e-01|-1.659203e-05 + 30|1.608399e-01|-3.298319e-05 + 31|1.608330e-01|-4.302183e-05 + 32|1.608310e-01|-1.273465e-05 + 33|1.608280e-01|-1.827713e-05 + 34|1.608231e-01|-3.039842e-05 + 35|1.608212e-01|-1.229256e-05 + 36|1.608200e-01|-6.900556e-06 + 37|1.608159e-01|-2.554039e-05 + 38|1.608103e-01|-3.521137e-05 + 39|1.608058e-01|-2.795180e-05 + It. |Loss |Delta loss + -------------------------------- + 40|1.608040e-01|-1.119118e-05 + 41|1.608027e-01|-8.193369e-06 + 42|1.607994e-01|-2.026719e-05 + 43|1.607985e-01|-5.819902e-06 + 44|1.607978e-01|-4.048170e-06 + 45|1.607978e-01|-3.007470e-07 + 46|1.607950e-01|-1.705375e-05 + 47|1.607927e-01|-1.430186e-05 + 48|1.607925e-01|-1.166526e-06 + 49|1.607911e-01|-9.069406e-06 + 50|1.607910e-01|-3.804209e-07 + 51|1.607910e-01|-5.942399e-08 + 52|1.607910e-01|-2.321380e-07 + 53|1.607907e-01|-1.877655e-06 + 54|1.607906e-01|-2.940224e-07 + 55|1.607877e-01|-1.814208e-05 + 56|1.607841e-01|-2.236496e-05 + 57|1.607810e-01|-1.951355e-05 + 58|1.607804e-01|-3.578228e-06 + 59|1.607789e-01|-9.442277e-06 + It. |Loss |Delta loss + -------------------------------- + 60|1.607779e-01|-5.997371e-06 + 61|1.607754e-01|-1.564408e-05 + 62|1.607742e-01|-7.693285e-06 + 63|1.607727e-01|-9.030547e-06 + 64|1.607719e-01|-5.103894e-06 + 65|1.607693e-01|-1.605420e-05 + 66|1.607676e-01|-1.047837e-05 + 67|1.607675e-01|-6.026848e-07 + 68|1.607655e-01|-1.240216e-05 + 69|1.607632e-01|-1.434674e-05 + 70|1.607618e-01|-8.829808e-06 + 71|1.607606e-01|-7.581824e-06 + 72|1.607590e-01|-1.009457e-05 + 73|1.607586e-01|-2.222963e-06 + 74|1.607577e-01|-5.564775e-06 + 75|1.607574e-01|-1.932763e-06 + 76|1.607573e-01|-8.148685e-07 + 77|1.607554e-01|-1.187660e-05 + 78|1.607546e-01|-4.557651e-06 + 79|1.607537e-01|-5.911902e-06 + It. |Loss |Delta loss + -------------------------------- + 80|1.607529e-01|-4.710187e-06 + 81|1.607528e-01|-8.866080e-07 + 82|1.607522e-01|-3.620627e-06 + 83|1.607514e-01|-5.091281e-06 + 84|1.607498e-01|-9.932095e-06 + 85|1.607487e-01|-6.852804e-06 + 86|1.607478e-01|-5.373596e-06 + 87|1.607473e-01|-3.287295e-06 + 88|1.607470e-01|-1.666655e-06 + 89|1.607469e-01|-5.293790e-07 + 90|1.607466e-01|-2.051914e-06 + 91|1.607456e-01|-6.422797e-06 + 92|1.607456e-01|-1.110433e-07 + 93|1.607451e-01|-2.803849e-06 + 94|1.607451e-01|-2.608066e-07 + 95|1.607441e-01|-6.290352e-06 + 96|1.607429e-01|-7.298455e-06 + 97|1.607429e-01|-8.969905e-09 + 98|1.607427e-01|-7.923968e-07 + 99|1.607427e-01|-3.519286e-07 + It. |Loss |Delta loss + -------------------------------- + 100|1.607426e-01|-3.563804e-07 + 101|1.607410e-01|-1.004042e-05 + 102|1.607410e-01|-2.124801e-07 + 103|1.607398e-01|-7.556935e-06 + 104|1.607398e-01|-7.606853e-08 + 105|1.607385e-01|-8.058684e-06 + 106|1.607383e-01|-7.393061e-07 + 107|1.607381e-01|-1.504958e-06 + 108|1.607377e-01|-2.508807e-06 + 109|1.607371e-01|-4.004631e-06 + 110|1.607365e-01|-3.580156e-06 + 111|1.607364e-01|-2.563573e-07 + 112|1.607354e-01|-6.390137e-06 + 113|1.607348e-01|-4.119553e-06 + 114|1.607339e-01|-5.299475e-06 + 115|1.607335e-01|-2.316767e-06 + 116|1.607330e-01|-3.444737e-06 + 117|1.607324e-01|-3.467980e-06 + 118|1.607320e-01|-2.374632e-06 + 119|1.607319e-01|-7.978255e-07 + It. |Loss |Delta loss + -------------------------------- + 120|1.607312e-01|-4.221434e-06 + 121|1.607310e-01|-1.324597e-06 + 122|1.607304e-01|-3.650359e-06 + 123|1.607298e-01|-3.732712e-06 + 124|1.607295e-01|-1.994082e-06 + 125|1.607289e-01|-3.954139e-06 + 126|1.607286e-01|-1.532372e-06 + 127|1.607286e-01|-1.167223e-07 + 128|1.607283e-01|-2.157376e-06 + 129|1.607279e-01|-2.253077e-06 + 130|1.607274e-01|-3.301532e-06 + 131|1.607269e-01|-2.650754e-06 + 132|1.607264e-01|-3.595551e-06 + 133|1.607262e-01|-1.159425e-06 + 134|1.607258e-01|-2.512411e-06 + 135|1.607255e-01|-1.998792e-06 + 136|1.607251e-01|-2.486536e-06 + 137|1.607246e-01|-2.782996e-06 + 138|1.607246e-01|-2.922470e-07 + 139|1.607242e-01|-2.071131e-06 + It. |Loss |Delta loss + -------------------------------- + 140|1.607237e-01|-3.154193e-06 + 141|1.607235e-01|-1.194962e-06 + 142|1.607232e-01|-2.035251e-06 + 143|1.607232e-01|-6.027855e-08 + 144|1.607229e-01|-1.555696e-06 + 145|1.607228e-01|-1.081740e-06 + 146|1.607225e-01|-1.881070e-06 + 147|1.607224e-01|-4.100096e-07 + 148|1.607223e-01|-7.785200e-07 + 149|1.607222e-01|-2.094072e-07 + 150|1.607220e-01|-1.440814e-06 + 151|1.607217e-01|-1.997794e-06 + 152|1.607214e-01|-2.011022e-06 + 153|1.607212e-01|-8.808854e-07 + 154|1.607211e-01|-7.245877e-07 + 155|1.607207e-01|-2.217159e-06 + 156|1.607201e-01|-3.817891e-06 + 157|1.607200e-01|-7.409600e-07 + 158|1.607198e-01|-1.497698e-06 + 159|1.607195e-01|-1.729666e-06 + It. |Loss |Delta loss + -------------------------------- + 160|1.607195e-01|-2.115187e-07 + 161|1.607192e-01|-1.643727e-06 + 162|1.607192e-01|-1.712969e-07 + 163|1.607189e-01|-1.805877e-06 + 164|1.607189e-01|-1.209827e-07 + 165|1.607185e-01|-2.060002e-06 + 166|1.607182e-01|-1.961341e-06 + 167|1.607181e-01|-1.020366e-06 + 168|1.607179e-01|-9.760982e-07 + 169|1.607178e-01|-7.219236e-07 + 170|1.607175e-01|-1.837718e-06 + 171|1.607174e-01|-3.337578e-07 + 172|1.607173e-01|-5.298564e-07 + 173|1.607173e-01|-6.864278e-08 + 174|1.607173e-01|-2.008419e-07 + 175|1.607171e-01|-1.375630e-06 + 176|1.607168e-01|-1.911257e-06 + 177|1.607167e-01|-2.709815e-07 + 178|1.607167e-01|-1.390953e-07 + 179|1.607165e-01|-1.199675e-06 + It. |Loss |Delta loss + -------------------------------- + 180|1.607165e-01|-1.457259e-07 + 181|1.607163e-01|-1.049154e-06 + 182|1.607163e-01|-2.753577e-09 + 183|1.607163e-01|-6.972814e-09 + 184|1.607161e-01|-1.552100e-06 + 185|1.607159e-01|-1.068596e-06 + 186|1.607157e-01|-1.247724e-06 + 187|1.607155e-01|-1.158164e-06 + 188|1.607155e-01|-2.616199e-07 + 189|1.607154e-01|-3.595874e-07 + 190|1.607154e-01|-5.334527e-08 + 191|1.607153e-01|-3.452744e-07 + 192|1.607153e-01|-1.239593e-07 + 193|1.607152e-01|-8.184984e-07 + 194|1.607150e-01|-1.316308e-06 + 195|1.607150e-01|-7.100882e-09 + 196|1.607148e-01|-1.393958e-06 + 197|1.607146e-01|-1.242735e-06 + 198|1.607144e-01|-1.123993e-06 + 199|1.607143e-01|-3.512071e-07 + It. |Loss |Delta loss + -------------------------------- + 200|1.607143e-01|-2.151971e-10 + It. |Loss |Delta loss + -------------------------------- + 0|-4.988764e-01|0.000000e+00 + 1|-4.993932e-01|-1.034993e-03 + 2|-4.993933e-01|-9.845917e-08 + 3|-4.993933e-01|-9.206594e-12 + + + + +| + + +.. code-block:: python + + + import numpy as np + import matplotlib.pylab as pl + import ot + + + + #%% parameters + + n=100 # nb bins + + # bin positions + x=np.arange(n,dtype=np.float64) + + # Gaussian distributions + a=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std + b=ot.datasets.get_1D_gauss(n,m=60,s=10) + + # loss matrix + M=ot.dist(x.reshape((n,1)),x.reshape((n,1))) + M/=M.max() + + #%% EMD + + G0=ot.emd(a,b,M) + + pl.figure(3) + ot.plot.plot1D_mat(a,b,G0,'OT matrix G0') + + #%% Example with Frobenius norm regularization + + def f(G): return 0.5*np.sum(G**2) + def df(G): return G + + reg=1e-1 + + Gl2=ot.optim.cg(a,b,M,reg,f,df,verbose=True) + + pl.figure(3) + ot.plot.plot1D_mat(a,b,Gl2,'OT matrix Frob. reg') + + #%% Example with entropic regularization + + def f(G): return np.sum(G*np.log(G)) + def df(G): return np.log(G)+1 + + reg=1e-3 + + Ge=ot.optim.cg(a,b,M,reg,f,df,verbose=True) + + pl.figure(4) + ot.plot.plot1D_mat(a,b,Ge,'OT matrix Entrop. reg') + + #%% Example with Frobenius norm + entropic regularization with gcg + + def f(G): return 0.5*np.sum(G**2) + def df(G): return G + + reg1=1e-1 + reg2=1e-1 + + Gel2=ot.optim.gcg(a,b,M,reg1,reg2,f,df,verbose=True) + + pl.figure(5) + ot.plot.plot1D_mat(a,b,Gel2,'OT entropic + matrix Frob. reg') +**Total running time of the script:** ( 0 minutes 2.358 seconds) + + + +.. container:: sphx-glr-footer + + + .. container:: sphx-glr-download + + :download:`Download Python source code: plot_optim_OTreg.py <plot_optim_OTreg.py>` + + + + .. container:: sphx-glr-download + + :download:`Download Jupyter notebook: plot_optim_OTreg.ipynb <plot_optim_OTreg.ipynb>` + +.. rst-class:: sphx-glr-signature + + `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_ diff --git a/docs/source/conf.py b/docs/source/conf.py index f76c184..1a12639 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -58,7 +58,7 @@ extensions = [ 'sphinx.ext.ifconfig', 'sphinx.ext.viewcode', 'sphinx.ext.napoleon', - 'sphinx_gallery.gen_gallery', +# 'sphinx_gallery.gen_gallery', ] # Add any paths that contain templates here, relative to this directory. |