Diffstat (limited to 'docs')
52 files changed, 1231 insertions, 62 deletions
diff --git a/docs/source/auto_examples/auto_examples_jupyter.zip b/docs/source/auto_examples/auto_examples_jupyter.zip
index 2b892cf..7c3de28 100644
--- a/docs/source/auto_examples/auto_examples_jupyter.zip
+++ b/docs/source/auto_examples/auto_examples_jupyter.zip
Binary files differ
diff --git a/docs/source/auto_examples/auto_examples_python.zip b/docs/source/auto_examples/auto_examples_python.zip
index 9799061..97377e1 100644
--- a/docs/source/auto_examples/auto_examples_python.zip
+++ b/docs/source/auto_examples/auto_examples_python.zip
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_001.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_001.png
index 9b0c3f5..da42bc1 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_001.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_002.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_002.png
index f2cf4f7..1f98598 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_002.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_002.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_003.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_003.png
index a4252cf..9e893d6 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_003.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_004.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_004.png
index 9bddc80..3bc248b 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_004.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_004.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png
index c31761d..e023ab4 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png
index f6c5d0e..dda21d4 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_003.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_003.png
index 8f5dfb9..f0967fb 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_003.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_004.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_004.png
index 309565f..809c8fc 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_004.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_004.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png
index 6ba8a56..887bdde 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png
index 7fa0f98..783c594 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_001.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_001.png
new file mode 100644
index 0000000..b159a6a
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_002.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_002.png
new file mode 100644
index 0000000..9f8e882
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_002.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_003.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_003.png
new file mode 100644
index 0000000..33058fc
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_004.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_004.png
new file mode 100644
index 0000000..9848bcb
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_004.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_005.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_005.png
new file mode 100644
index 0000000..6616d3c
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_005.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_006.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_006.png
new file mode 100644
index 0000000..8575d93
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_006.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_WDA_001.png b/docs/source/auto_examples/images/sphx_glr_plot_WDA_001.png
index c577317..250155d 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_WDA_001.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_WDA_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_001.png b/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_001.png
new file mode 100644
index 0000000..4917903
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_002.png b/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_002.png
new file mode 100644
index 0000000..7c06255
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_002.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_1D_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_1D_thumb.png
index 9d0ec13..15c9825 100644
--- a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_1D_thumb.png
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_1D_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_sinkhornlp_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_sinkhornlp_thumb.png
new file mode 100644
index 0000000..3015582
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_sinkhornlp_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png
index d64d50b..bac78f0 100644
--- a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_L1_vs_L2_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_L1_vs_L2_thumb.png
new file mode 100644
index 0000000..c67e8aa
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_L1_vs_L2_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_conv_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_conv_thumb.png
new file mode 100644
index 0000000..3015582
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_conv_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_WDA_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_WDA_thumb.png
index 2597600..84759e8 100644
--- a/docs/source/auto_examples/images/thumb/sphx_glr_plot_WDA_thumb.png
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_WDA_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_compute_emd_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_compute_emd_thumb.png
new file mode 100644
index 0000000..67d2ca1
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_compute_emd_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_test_OT_2D_samples_stabilized_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_test_OT_2D_samples_stabilized_thumb.png
new file mode 100644
index 0000000..cbc8e0f
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_test_OT_2D_samples_stabilized_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/index.rst b/docs/source/auto_examples/index.rst
index 868ca76..1695300 100644
--- a/docs/source/auto_examples/index.rst
+++ b/docs/source/auto_examples/index.rst
@@ -83,6 +83,26 @@ POT Examples
 .. raw:: html
 
+    <div class="sphx-glr-thumbcontainer" tooltip="@author: rflamary ">
+
+.. only:: html
+
+    .. figure:: /auto_examples/images/thumb/sphx_glr_plot_compute_emd_thumb.png
+
+        :ref:`sphx_glr_auto_examples_plot_compute_emd.py`
+
+.. raw:: html
+
+    </div>
+
+
+.. toctree::
+   :hidden:
+
+   /auto_examples/plot_compute_emd
+
+.. raw:: html
+
     <div class="sphx-glr-thumbcontainer" tooltip="[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized discrete optima...">
 
 .. only:: html
 
@@ -143,6 +163,26 @@ POT Examples
 .. raw:: html
 
+    <div class="sphx-glr-thumbcontainer" tooltip="Stole the figure idea from Fig. 1 and 2 in https://arxiv.org/pdf/1706.07650.pdf">
+
+.. only:: html
+
+    .. figure:: /auto_examples/images/thumb/sphx_glr_plot_OT_L1_vs_L2_thumb.png
+
+        :ref:`sphx_glr_auto_examples_plot_OT_L1_vs_L2.py`
+
+.. raw:: html
+
+    </div>
+
+
+.. toctree::
+   :hidden:
+
+   /auto_examples/plot_OT_L1_vs_L2
+
+.. raw:: html
+
     <div class="sphx-glr-thumbcontainer" tooltip=" @author: rflamary ">
 
 .. only:: html
 
diff --git a/docs/source/auto_examples/plot_OT_1D.ipynb b/docs/source/auto_examples/plot_OT_1D.ipynb
index 17d0b21..8715b97 100644
--- a/docs/source/auto_examples/plot_OT_1D.ipynb
+++ b/docs/source/auto_examples/plot_OT_1D.ipynb
@@ -24,7 +24,7 @@
       "execution_count": null,
       "cell_type": "code",
       "source": [
-        "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\nfrom ot.datasets import get_1D_gauss as gauss\n\n\n#%% parameters\n\nn=100 # nb bins\n\n# bin positions\nx=np.arange(n,dtype=np.float64)\n\n# Gaussian distributions\na=gauss(n,m=20,s=5) # m= mean, s= std\nb=gauss(n,m=60,s=10)\n\n# loss matrix\nM=ot.dist(x.reshape((n,1)),x.reshape((n,1)))\nM/=M.max()\n\n#%% plot the distributions\n\npl.figure(1)\npl.plot(x,a,'b',label='Source distribution')\npl.plot(x,b,'r',label='Target distribution')\npl.legend()\n\n#%% plot distributions and loss matrix\n\npl.figure(2)\not.plot.plot1D_mat(a,b,M,'Cost matrix M')\n\n#%% EMD\n\nG0=ot.emd(a,b,M)\n\npl.figure(3)\not.plot.plot1D_mat(a,b,G0,'OT matrix G0')\n\n#%% Sinkhorn\n\nlambd=1e-3\nGs=ot.sinkhorn(a,b,M,lambd)\n\npl.figure(4)\not.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn')"
+        "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\nfrom ot.datasets import get_1D_gauss as gauss\n\n\n#%% parameters\n\nn=100 # nb bins\n\n# bin positions\nx=np.arange(n,dtype=np.float64)\n\n# Gaussian distributions\na=gauss(n,m=20,s=5) # m= mean, s= std\nb=gauss(n,m=60,s=10)\n\n# loss matrix\nM=ot.dist(x.reshape((n,1)),x.reshape((n,1)))\nM/=M.max()\n\n#%% plot the distributions\n\npl.figure(1)\npl.plot(x,a,'b',label='Source distribution')\npl.plot(x,b,'r',label='Target distribution')\npl.legend()\n\n#%% plot distributions and loss matrix\n\npl.figure(2)\not.plot.plot1D_mat(a,b,M,'Cost matrix M')\n\n#%% EMD\n\nG0=ot.emd(a,b,M)\n\npl.figure(3)\not.plot.plot1D_mat(a,b,G0,'OT matrix G0')\n\n#%% Sinkhorn\n\nlambd=1e-3\nGs=ot.sinkhorn(a,b,M,lambd,verbose=True)\n\npl.figure(4)\not.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn')"
       ],
       "outputs": [],
       "metadata": {
diff --git a/docs/source/auto_examples/plot_OT_1D.py b/docs/source/auto_examples/plot_OT_1D.py
index e5719eb..6661aa3 100644
--- a/docs/source/auto_examples/plot_OT_1D.py
+++ b/docs/source/auto_examples/plot_OT_1D.py
@@ -50,7 +50,7 @@ ot.plot.plot1D_mat(a,b,G0,'OT matrix G0')
 #%% Sinkhorn
 
 lambd=1e-3
-Gs=ot.sinkhorn(a,b,M,lambd)
+Gs=ot.sinkhorn(a,b,M,lambd,verbose=True)
 
 pl.figure(4)
 ot.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn')
diff --git a/docs/source/auto_examples/plot_OT_1D.rst b/docs/source/auto_examples/plot_OT_1D.rst
index 941fd54..44b715b 100644
--- a/docs/source/auto_examples/plot_OT_1D.rst
+++ b/docs/source/auto_examples/plot_OT_1D.rst
@@ -36,7 +36,29 @@
       :scale: 47
 
 
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+    It. |Err
+    -------------------
+        0|8.187970e-02|
+       10|3.460174e-02|
+       20|6.633335e-03|
+       30|9.797798e-04|
+       40|1.389606e-04|
+       50|1.959016e-05|
+       60|2.759079e-06|
+       70|3.885166e-07|
+       80|5.470605e-08|
+       90|7.702918e-09|
+      100|1.084609e-09|
+      110|1.527180e-10|
+
+
+
+
+|
 
 .. code-block:: python
 
@@ -85,12 +107,12 @@
     #%% Sinkhorn
 
     lambd=1e-3
-    Gs=ot.sinkhorn(a,b,M,lambd)
+    Gs=ot.sinkhorn(a,b,M,lambd,verbose=True)
 
     pl.figure(4)
     ot.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn')
 
-**Total running time of the script:** ( 0 minutes 0.597 seconds)
+**Total running time of the script:** ( 0 minutes 0.674 seconds)
 
 
 
diff --git a/docs/source/auto_examples/plot_OT_2D_samples.ipynb b/docs/source/auto_examples/plot_OT_2D_samples.ipynb
index 7d42ba7..fad0467 100644
--- a/docs/source/auto_examples/plot_OT_2D_samples.ipynb
+++ b/docs/source/auto_examples/plot_OT_2D_samples.ipynb
@@ -24,7 +24,7 @@
       "execution_count": null,
       "cell_type": "code",
      "source": [
-        "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\n\n#%% parameters and data generation\n\nn=2 # nb samples\n\nmu_s=np.array([0,0])\ncov_s=np.array([[1,0],[0,1]])\n\nmu_t=np.array([4,4])\ncov_t=np.array([[1,-.8],[-.8,1]])\n\nxs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s)\nxt=ot.datasets.get_2D_samples_gauss(n,mu_t,cov_t)\n\na,b = ot.unif(n),ot.unif(n) # uniform distribution on samples\n\n# loss matrix\nM=ot.dist(xs,xt)\nM/=M.max()\n\n#%% plot samples\n\npl.figure(1)\npl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\npl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\npl.legend(loc=0)\npl.title('Source and traget distributions')\n\npl.figure(2)\npl.imshow(M,interpolation='nearest')\npl.title('Cost matrix M')\n\n\n#%% EMD\n\nG0=ot.emd(a,b,M)\n\npl.figure(3)\npl.imshow(G0,interpolation='nearest')\npl.title('OT matrix G0')\n\npl.figure(4)\not.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])\npl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\npl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\npl.legend(loc=0)\npl.title('OT matrix with samples')\n\n\n#%% sinkhorn\n\n# reg term\nlambd=5e-3\n\nGs=ot.sinkhorn(a,b,M,lambd)\n\npl.figure(5)\npl.imshow(Gs,interpolation='nearest')\npl.title('OT matrix sinkhorn')\n\npl.figure(6)\not.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1])\npl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\npl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\npl.legend(loc=0)\npl.title('OT matrix Sinkhorn with samples')"
+        "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\n\n#%% parameters and data generation\n\nn=50 # nb samples\n\nmu_s=np.array([0,0])\ncov_s=np.array([[1,0],[0,1]])\n\nmu_t=np.array([4,4])\ncov_t=np.array([[1,-.8],[-.8,1]])\n\nxs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s)\nxt=ot.datasets.get_2D_samples_gauss(n,mu_t,cov_t)\n\na,b = ot.unif(n),ot.unif(n) # uniform distribution on samples\n\n# loss matrix\nM=ot.dist(xs,xt)\nM/=M.max()\n\n#%% plot samples\n\npl.figure(1)\npl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\npl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\npl.legend(loc=0)\npl.title('Source and traget distributions')\n\npl.figure(2)\npl.imshow(M,interpolation='nearest')\npl.title('Cost matrix M')\n\n\n#%% EMD\n\nG0=ot.emd(a,b,M)\n\npl.figure(3)\npl.imshow(G0,interpolation='nearest')\npl.title('OT matrix G0')\n\npl.figure(4)\not.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])\npl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\npl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\npl.legend(loc=0)\npl.title('OT matrix with samples')\n\n\n#%% sinkhorn\n\n# reg term\nlambd=5e-4\n\nGs=ot.sinkhorn(a,b,M,lambd)\n\npl.figure(5)\npl.imshow(Gs,interpolation='nearest')\npl.title('OT matrix sinkhorn')\n\npl.figure(6)\not.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1])\npl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\npl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\npl.legend(loc=0)\npl.title('OT matrix Sinkhorn with samples')"
       ],
       "outputs": [],
       "metadata": {
diff --git a/docs/source/auto_examples/plot_OT_2D_samples.py b/docs/source/auto_examples/plot_OT_2D_samples.py
index 3b95083..edfb781 100644
--- a/docs/source/auto_examples/plot_OT_2D_samples.py
+++ b/docs/source/auto_examples/plot_OT_2D_samples.py
@@ -13,7 +13,7 @@ import ot
 
 #%% parameters and data generation
 
-n=2 # nb samples
+n=50 # nb samples
 
 mu_s=np.array([0,0])
 cov_s=np.array([[1,0],[0,1]])
@@ -62,7 +62,7 @@ pl.title('OT matrix with samples')
 #%% sinkhorn
 
 # reg term
-lambd=5e-3
+lambd=5e-4
 
 Gs=ot.sinkhorn(a,b,M,lambd)
 
diff --git a/docs/source/auto_examples/plot_OT_2D_samples.rst b/docs/source/auto_examples/plot_OT_2D_samples.rst
index 01e5f31..e05e591 100644
--- a/docs/source/auto_examples/plot_OT_2D_samples.rst
+++ b/docs/source/auto_examples/plot_OT_2D_samples.rst
@@ -46,7 +46,16 @@
       :scale: 47
 
 
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+    ('Warning: numerical errors at iteration', 0)
+
+
+
+
+|
 
 .. code-block:: python
 
@@ -58,7 +67,7 @@
 
     #%% parameters and data generation
 
-    n=2 # nb samples
+    n=50 # nb samples
 
     mu_s=np.array([0,0])
     cov_s=np.array([[1,0],[0,1]])
@@ -107,7 +116,7 @@
     #%% sinkhorn
 
     # reg term
-    lambd=5e-3
+    lambd=5e-4
 
     Gs=ot.sinkhorn(a,b,M,lambd)
 
@@ -122,7 +131,7 @@
     pl.legend(loc=0)
     pl.title('OT matrix Sinkhorn with samples')
 
-**Total running time of the script:** ( 0 minutes 0.406 seconds)
+**Total running time of the script:** ( 0 minutes 0.623 seconds)
 
 
 
diff --git a/docs/source/auto_examples/plot_OT_L1_vs_L2.ipynb b/docs/source/auto_examples/plot_OT_L1_vs_L2.ipynb
new file mode 100644
index 0000000..46283ac
--- /dev/null
+++ b/docs/source/auto_examples/plot_OT_L1_vs_L2.ipynb
@@ -0,0 +1,54 @@
+{
+  "nbformat_minor": 0,
+  "nbformat": 4,
+  "cells": [
+    {
+      "execution_count": null,
+      "cell_type": "code",
+      "source": [
+        "%matplotlib inline"
+      ],
+      "outputs": [],
+      "metadata": {
+        "collapsed": false
+      }
+    },
+    {
+      "source": [
+        "\n# 2D Optimal transport for different metrics\n\n\nStole the figure idea from Fig. 1 and 2 in \nhttps://arxiv.org/pdf/1706.07650.pdf\n\n\n@author: rflamary\n\n"
+      ],
+      "cell_type": "markdown",
+      "metadata": {}
+    },
+    {
+      "execution_count": null,
+      "cell_type": "code",
+      "source": [
+        "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\n\n#%% parameters and data generation\n\nfor data in range(2):\n\n    if data:\n        n=20 # nb samples\n        xs=np.zeros((n,2))\n        xs[:,0]=np.arange(n)+1\n        xs[:,1]=(np.arange(n)+1)*-0.001 # to make it strictly convex...\n\n        xt=np.zeros((n,2))\n        xt[:,1]=np.arange(n)+1\n    else:\n\n        n=50 # nb samples\n        xtot=np.zeros((n+1,2))\n        xtot[:,0]=np.cos((np.arange(n+1)+1.0)*0.9/(n+2)*2*np.pi)\n        xtot[:,1]=np.sin((np.arange(n+1)+1.0)*0.9/(n+2)*2*np.pi)\n\n        xs=xtot[:n,:]\n        xt=xtot[1:,:]\n\n\n\n    a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples\n\n    # loss matrix\n    M1=ot.dist(xs,xt,metric='euclidean')\n    M1/=M1.max()\n\n    # loss matrix\n    M2=ot.dist(xs,xt,metric='sqeuclidean')\n    M2/=M2.max()\n\n    # loss matrix\n    Mp=np.sqrt(ot.dist(xs,xt,metric='euclidean'))\n    Mp/=Mp.max()\n\n    #%% plot samples\n\n    pl.figure(1+3*data)\n    pl.clf()\n    pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\n    pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\n    pl.axis('equal')\n    pl.title('Source and traget distributions')\n\n    pl.figure(2+3*data,(15,5))\n    pl.subplot(1,3,1)\n    pl.imshow(M1,interpolation='nearest')\n    pl.title('Eucidean cost')\n    pl.subplot(1,3,2)\n    pl.imshow(M2,interpolation='nearest')\n    pl.title('Squared Euclidean cost')\n\n    pl.subplot(1,3,3)\n    pl.imshow(Mp,interpolation='nearest')\n    pl.title('Sqrt Euclidean cost')\n    #%% EMD\n\n    G1=ot.emd(a,b,M1)\n    G2=ot.emd(a,b,M2)\n    Gp=ot.emd(a,b,Mp)\n\n    pl.figure(3+3*data,(15,5))\n\n    pl.subplot(1,3,1)\n    ot.plot.plot2D_samples_mat(xs,xt,G1,c=[.5,.5,1])\n    pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\n    pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\n    pl.axis('equal')\n    #pl.legend(loc=0)\n    pl.title('OT Euclidean')\n\n    pl.subplot(1,3,2)\n\n    ot.plot.plot2D_samples_mat(xs,xt,G2,c=[.5,.5,1])\n    pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\n    pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\n    pl.axis('equal')\n    #pl.legend(loc=0)\n    pl.title('OT squared Euclidean')\n\n    pl.subplot(1,3,3)\n\n    ot.plot.plot2D_samples_mat(xs,xt,Gp,c=[.5,.5,1])\n    pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\n    pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\n    pl.axis('equal')\n    #pl.legend(loc=0)\n    pl.title('OT sqrt Euclidean')"
+      ],
+      "outputs": [],
+      "metadata": {
+        "collapsed": false
+      }
+    }
+  ],
+  "metadata": {
+    "kernelspec": {
+      "display_name": "Python 2",
+      "name": "python2",
+      "language": "python"
+    },
+    "language_info": {
+      "mimetype": "text/x-python",
+      "nbconvert_exporter": "python",
+      "name": "python",
+      "file_extension": ".py",
+      "version": "2.7.12",
+      "pygments_lexer": "ipython2",
+      "codemirror_mode": {
+        "version": 2,
+        "name": "ipython"
+      }
+    }
+  }
+}
\ No newline at end of file
diff --git a/docs/source/auto_examples/plot_OT_L1_vs_L2.py b/docs/source/auto_examples/plot_OT_L1_vs_L2.py
new file mode 100644
index 0000000..9bb92fe
--- /dev/null
+++ b/docs/source/auto_examples/plot_OT_L1_vs_L2.py
@@ -0,0 +1,108 @@
+# -*- coding: utf-8 -*-
+"""
+==========================================
+2D Optimal transport for different metrics
+==========================================
+
+Stole the figure idea from Fig. 1 and 2 in
+https://arxiv.org/pdf/1706.07650.pdf
+
+
+@author: rflamary
+"""
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+#%% parameters and data generation
+
+for data in range(2):
+
+    if data:
+        n=20 # nb samples
+        xs=np.zeros((n,2))
+        xs[:,0]=np.arange(n)+1
+        xs[:,1]=(np.arange(n)+1)*-0.001 # to make it strictly convex...
+
+        xt=np.zeros((n,2))
+        xt[:,1]=np.arange(n)+1
+    else:
+
+        n=50 # nb samples
+        xtot=np.zeros((n+1,2))
+        xtot[:,0]=np.cos((np.arange(n+1)+1.0)*0.9/(n+2)*2*np.pi)
+        xtot[:,1]=np.sin((np.arange(n+1)+1.0)*0.9/(n+2)*2*np.pi)
+
+        xs=xtot[:n,:]
+        xt=xtot[1:,:]
+
+
+
+    a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples
+
+    # loss matrix
+    M1=ot.dist(xs,xt,metric='euclidean')
+    M1/=M1.max()
+
+    # loss matrix
+    M2=ot.dist(xs,xt,metric='sqeuclidean')
+    M2/=M2.max()
+
+    # loss matrix
+    Mp=np.sqrt(ot.dist(xs,xt,metric='euclidean'))
+    Mp/=Mp.max()
+
+    #%% plot samples
+
+    pl.figure(1+3*data)
+    pl.clf()
+    pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
+    pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
+    pl.axis('equal')
+    pl.title('Source and traget distributions')
+
+    pl.figure(2+3*data,(15,5))
+    pl.subplot(1,3,1)
+    pl.imshow(M1,interpolation='nearest')
+    pl.title('Eucidean cost')
+    pl.subplot(1,3,2)
+    pl.imshow(M2,interpolation='nearest')
+    pl.title('Squared Euclidean cost')
+
+    pl.subplot(1,3,3)
+    pl.imshow(Mp,interpolation='nearest')
+    pl.title('Sqrt Euclidean cost')
+    #%% EMD
+
+    G1=ot.emd(a,b,M1)
+    G2=ot.emd(a,b,M2)
+    Gp=ot.emd(a,b,Mp)
+
+    pl.figure(3+3*data,(15,5))
+
+    pl.subplot(1,3,1)
+    ot.plot.plot2D_samples_mat(xs,xt,G1,c=[.5,.5,1])
+    pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
+    pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
+    pl.axis('equal')
+    #pl.legend(loc=0)
+    pl.title('OT Euclidean')
+
+    pl.subplot(1,3,2)
+
+    ot.plot.plot2D_samples_mat(xs,xt,G2,c=[.5,.5,1])
+    pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
+    pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
+    pl.axis('equal')
+    #pl.legend(loc=0)
+    pl.title('OT squared Euclidean')
+
+    pl.subplot(1,3,3)
+
+    ot.plot.plot2D_samples_mat(xs,xt,Gp,c=[.5,.5,1])
+    pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
+    pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
+    pl.axis('equal')
+    #pl.legend(loc=0)
+    pl.title('OT sqrt Euclidean')
diff --git a/docs/source/auto_examples/plot_OT_L1_vs_L2.rst b/docs/source/auto_examples/plot_OT_L1_vs_L2.rst
new file mode 100644
index 0000000..4e94bef
--- /dev/null
+++ b/docs/source/auto_examples/plot_OT_L1_vs_L2.rst
@@ -0,0 +1,174 @@
+
+
+.. _sphx_glr_auto_examples_plot_OT_L1_vs_L2.py:
+
+
+==========================================
+2D Optimal transport for different metrics
+==========================================
+
+Stole the figure idea from Fig. 1 and 2 in
+https://arxiv.org/pdf/1706.07650.pdf
+
+
+@author: rflamary
+
+
+
+
+.. rst-class:: sphx-glr-horizontal
+
+
+    *
+
+      .. image:: /auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_001.png
+            :scale: 47
+
+    *
+
+      .. image:: /auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_002.png
+            :scale: 47
+
+    *
+
+      .. image:: /auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_003.png
+            :scale: 47
+
+    *
+
+      .. image:: /auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_004.png
+            :scale: 47
+
+    *
+
+      .. image:: /auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_005.png
+            :scale: 47
+
+    *
+
+      .. image:: /auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_006.png
+            :scale: 47
+
+
+
+
+
+.. code-block:: python
+
+
+    import numpy as np
+    import matplotlib.pylab as pl
+    import ot
+
+    #%% parameters and data generation
+
+    for data in range(2):
+
+        if data:
+            n=20 # nb samples
+            xs=np.zeros((n,2))
+            xs[:,0]=np.arange(n)+1
+            xs[:,1]=(np.arange(n)+1)*-0.001 # to make it strictly convex...
+
+            xt=np.zeros((n,2))
+            xt[:,1]=np.arange(n)+1
+        else:
+
+            n=50 # nb samples
+            xtot=np.zeros((n+1,2))
+            xtot[:,0]=np.cos((np.arange(n+1)+1.0)*0.9/(n+2)*2*np.pi)
+            xtot[:,1]=np.sin((np.arange(n+1)+1.0)*0.9/(n+2)*2*np.pi)
+
+            xs=xtot[:n,:]
+            xt=xtot[1:,:]
+
+
+
+        a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples
+
+        # loss matrix
+        M1=ot.dist(xs,xt,metric='euclidean')
+        M1/=M1.max()
+
+        # loss matrix
+        M2=ot.dist(xs,xt,metric='sqeuclidean')
+        M2/=M2.max()
+
+        # loss matrix
+        Mp=np.sqrt(ot.dist(xs,xt,metric='euclidean'))
+        Mp/=Mp.max()
+
+        #%% plot samples
+
+        pl.figure(1+3*data)
+        pl.clf()
+        pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
+        pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
+        pl.axis('equal')
+        pl.title('Source and traget distributions')
+
+        pl.figure(2+3*data,(15,5))
+        pl.subplot(1,3,1)
+        pl.imshow(M1,interpolation='nearest')
+        pl.title('Eucidean cost')
+        pl.subplot(1,3,2)
+        pl.imshow(M2,interpolation='nearest')
+        pl.title('Squared Euclidean cost')
+
+        pl.subplot(1,3,3)
+        pl.imshow(Mp,interpolation='nearest')
+        pl.title('Sqrt Euclidean cost')
+        #%% EMD
+
+        G1=ot.emd(a,b,M1)
+        G2=ot.emd(a,b,M2)
+        Gp=ot.emd(a,b,Mp)
+
+        pl.figure(3+3*data,(15,5))
+
+        pl.subplot(1,3,1)
+        ot.plot.plot2D_samples_mat(xs,xt,G1,c=[.5,.5,1])
+        pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
+        pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
+        pl.axis('equal')
+        #pl.legend(loc=0)
+        pl.title('OT Euclidean')
+
+        pl.subplot(1,3,2)
+
+        ot.plot.plot2D_samples_mat(xs,xt,G2,c=[.5,.5,1])
+        pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
+        pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
+        pl.axis('equal')
+        #pl.legend(loc=0)
+        pl.title('OT squared Euclidean')
+
+        pl.subplot(1,3,3)
+
+        ot.plot.plot2D_samples_mat(xs,xt,Gp,c=[.5,.5,1])
+        pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
+        pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
+        pl.axis('equal')
+        #pl.legend(loc=0)
+        pl.title('OT sqrt Euclidean')
+
+**Total running time of the script:** ( 0 minutes 1.417 seconds)
+
+
+
+.. container:: sphx-glr-footer
+
+
+  .. container:: sphx-glr-download
+
+     :download:`Download Python source code: plot_OT_L1_vs_L2.py <plot_OT_L1_vs_L2.py>`
+
+
+
+  .. container:: sphx-glr-download
+
+     :download:`Download Jupyter notebook: plot_OT_L1_vs_L2.ipynb <plot_OT_L1_vs_L2.ipynb>`
+
+.. rst-class:: sphx-glr-signature
+
+    `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_OT_conv.ipynb b/docs/source/auto_examples/plot_OT_conv.ipynb
new file mode 100644
index 0000000..7fc4af0
--- /dev/null
+++ b/docs/source/auto_examples/plot_OT_conv.ipynb
@@ -0,0 +1,54 @@
+{
+  "nbformat_minor": 0,
+  "nbformat": 4,
+  "cells": [
+    {
+      "execution_count": null,
+      "cell_type": "code",
+      "source": [
+        "%matplotlib inline"
+      ],
+      "outputs": [],
+      "metadata": {
+        "collapsed": false
+      }
+    },
+    {
+      "source": [
+        "\n# 1D Wasserstein barycenter demo\n\n\n\n@author: rflamary\n\n"
+      ],
+      "cell_type": "markdown",
+      "metadata": {}
+    },
+    {
+      "execution_count": null,
+      "cell_type": "code",
+      "source": [
+        "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\nfrom mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used\nimport scipy as sp\nimport scipy.signal as sps\n#%% parameters\n\nn=10 # nb bins\n\n# bin positions\nx=np.arange(n,dtype=np.float64)\n\nxx,yy=np.meshgrid(x,x)\n\n\nxpos=np.hstack((xx.reshape(-1,1),yy.reshape(-1,1)))\n\nM=ot.dist(xpos)\n\n\nI0=((xx-5)**2+(yy-5)**2<3**2)*1.0\nI1=((xx-7)**2+(yy-7)**2<3**2)*1.0\n\nI0/=I0.sum()\nI1/=I1.sum()\n\ni0=I0.ravel()\ni1=I1.ravel()\n\nM=M[i0>0,:][:,i1>0].copy()\ni0=i0[i0>0]\ni1=i1[i1>0]\nItot=np.concatenate((I0[:,:,np.newaxis],I1[:,:,np.newaxis]),2)\n\n\n#%% plot the distributions\n\npl.figure(1)\npl.subplot(2,2,1)\npl.imshow(I0)\npl.subplot(2,2,2)\npl.imshow(I1)\n\n\n#%% barycenter computation\n\nalpha=0.5 # 0<=alpha<=1\nweights=np.array([1-alpha,alpha])\n\n\ndef conv2(I,k):\n    return sp.ndimage.convolve1d(sp.ndimage.convolve1d(I,k,axis=1),k,axis=0)\n\ndef conv2n(I,k):\n    res=np.zeros_like(I)\n    for i in range(I.shape[2]):\n        res[:,:,i]=conv2(I[:,:,i],k)\n    return res\n\n\ndef get_1Dkernel(reg,thr=1e-16,wmax=1024):\n    w=max(min(wmax,2*int((-np.log(thr)*reg)**(.5))),3)\n    x=np.arange(w,dtype=np.float64)\n    return np.exp(-((x-w/2)**2)/reg)\n\nthr=1e-16\nreg=1e0\n\nk=get_1Dkernel(reg)\npl.figure(2)\npl.plot(k)\n\nI05=conv2(I0,k)\n\npl.figure(1)\npl.subplot(2,2,1)\npl.imshow(I0)\npl.subplot(2,2,2)\npl.imshow(I05)\n\n#%%\n\nG=ot.emd(i0,i1,M)\nr0=np.sum(M*G)\n\nreg=1e-1\nGs=ot.bregman.sinkhorn_knopp(i0,i1,M,reg=reg)\nrs=np.sum(M*Gs)\n\n#%%\n\ndef mylog(u):\n    tmp=np.log(u)\n    tmp[np.isnan(tmp)]=0\n    return tmp\n\ndef sinkhorn_conv(a,b, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):\n\n\n    a=np.asarray(a,dtype=np.float64)\n    b=np.asarray(b,dtype=np.float64)\n\n\n    if len(b.shape)>2:\n        nbb=b.shape[2]\n        a=a[:,:,np.newaxis]\n    else:\n        nbb=0\n\n\n    if log:\n        log={'err':[]}\n\n    # we assume that no distances are null except those of the diagonal of distances\n    if nbb:\n        u = np.ones((a.shape[0],a.shape[1],nbb))/(np.prod(a.shape[:2]))\n        v = np.ones((a.shape[0],a.shape[1],nbb))/(np.prod(b.shape[:2]))\n        a0=1.0/(np.prod(b.shape[:2]))\n    else:\n        u = np.ones((a.shape[0],a.shape[1]))/(np.prod(a.shape[:2]))\n        v = np.ones((a.shape[0],a.shape[1]))/(np.prod(b.shape[:2]))\n        a0=1.0/(np.prod(b.shape[:2]))\n\n\n    k=get_1Dkernel(reg)\n\n    if nbb:\n        K=lambda I: conv2n(I,k)\n    else:\n        K=lambda I: conv2(I,k)\n\n    cpt = 0\n    err=1\n    while (err>stopThr and cpt<numItermax):\n        uprev = u\n        vprev = v\n\n        v = np.divide(b, K(u))\n        u = np.divide(a, K(v))\n\n        if (np.any(np.isnan(u)) or np.any(np.isnan(v)) \n                or np.any(np.isinf(u)) or np.any(np.isinf(v))):\n            # we have reached the machine precision\n            # come back to previous solution and quit loop\n            print('Warning: numerical errors at iteration', cpt)\n            u = uprev\n            v = vprev\n            break\n        if cpt%10==0:\n            # we can speed up the process by checking for the error only all the 10th iterations\n\n            err = np.sum((u-uprev)**2)/np.sum((u)**2)+np.sum((v-vprev)**2)/np.sum((v)**2)\n\n            if log:\n                log['err'].append(err)\n\n            if verbose:\n                if cpt%200 ==0:\n                    print('{:5s}|{:12s}'.format('It.','Err')+'\\n'+'-'*19)\n                print('{:5d}|{:8e}|'.format(cpt,err))\n        cpt = cpt +1\n    if log:\n        log['u']=u\n        log['v']=v\n\n    if nbb: #return only loss \n        res=np.zeros((nbb))\n        for i in range(nbb):\n            res[i]=np.sum(u[:,i].reshape((-1,1))*K*v[:,i].reshape((1,-1))*M)\n        if log:\n            return res,log\n        else:\n            return res \n\n    else: # return OT matrix\n        res=reg*a0*np.sum(a*mylog(u+(u==0))+b*mylog(v+(v==0)))\n        if log:\n\n            return res,log\n        else:\n            return res\n\nreg=1e0\nr,log=sinkhorn_conv(I0,I1,reg,verbose=True,log=True)\na=I0\nb=I1\nu=log['u']\nv=log['v']\n#%% barycenter interpolation"
+      ],
+      "outputs": [],
+      "metadata": {
+        "collapsed": false
+      }
+    }
+  ],
+  "metadata": {
+    "kernelspec": {
+      "display_name": "Python 2",
+      "name": "python2",
+      "language": "python"
+    },
+    "language_info": {
+      "mimetype": "text/x-python",
+      "nbconvert_exporter": "python",
+      "name": "python",
+      "file_extension": ".py",
+      "version": "2.7.12",
+      "pygments_lexer": "ipython2",
+      "codemirror_mode": {
+        "version": 2,
+        "name": "ipython"
+      }
+    }
+  }
+}
\ No newline at end of file
diff --git a/docs/source/auto_examples/plot_OT_conv.py b/docs/source/auto_examples/plot_OT_conv.py
new file mode 100644
index 0000000..a86e7a2
--- /dev/null
+++ b/docs/source/auto_examples/plot_OT_conv.py
@@ -0,0 +1,200 @@
+# -*- coding: utf-8 -*-
+"""
+==============================
+1D Wasserstein barycenter demo
+==============================
+
+
+@author: rflamary
+"""
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+from mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used
+import scipy as sp
+import scipy.signal as sps
+#%% parameters
+
+n=10 # nb bins
+
+# bin positions
+x=np.arange(n,dtype=np.float64)
+
+xx,yy=np.meshgrid(x,x)
+
+
+xpos=np.hstack((xx.reshape(-1,1),yy.reshape(-1,1)))
+
+M=ot.dist(xpos)
+
+
+I0=((xx-5)**2+(yy-5)**2<3**2)*1.0
+I1=((xx-7)**2+(yy-7)**2<3**2)*1.0
+
+I0/=I0.sum()
+I1/=I1.sum()
+
+i0=I0.ravel()
+i1=I1.ravel()
+
+M=M[i0>0,:][:,i1>0].copy()
+i0=i0[i0>0]
+i1=i1[i1>0]
+Itot=np.concatenate((I0[:,:,np.newaxis],I1[:,:,np.newaxis]),2)
+
+
+#%% plot the distributions
+
+pl.figure(1)
+pl.subplot(2,2,1)
+pl.imshow(I0)
+pl.subplot(2,2,2)
+pl.imshow(I1)
+
+
+#%% barycenter computation
+
+alpha=0.5 # 0<=alpha<=1
+weights=np.array([1-alpha,alpha])
+
+
+def conv2(I,k):
+    return sp.ndimage.convolve1d(sp.ndimage.convolve1d(I,k,axis=1),k,axis=0)
+
+def conv2n(I,k):
+    res=np.zeros_like(I)
+    for i in range(I.shape[2]):
+        res[:,:,i]=conv2(I[:,:,i],k)
+    return res
+
+
+def get_1Dkernel(reg,thr=1e-16,wmax=1024):
+    w=max(min(wmax,2*int((-np.log(thr)*reg)**(.5))),3)
+    x=np.arange(w,dtype=np.float64)
+    return np.exp(-((x-w/2)**2)/reg)
+
+thr=1e-16
+reg=1e0
+
+k=get_1Dkernel(reg)
+pl.figure(2)
+pl.plot(k)
+
+I05=conv2(I0,k)
+
+pl.figure(1)
+pl.subplot(2,2,1)
+pl.imshow(I0)
+pl.subplot(2,2,2)
+pl.imshow(I05)
+
+#%%
+
+G=ot.emd(i0,i1,M)
+r0=np.sum(M*G)
+
+reg=1e-1
+Gs=ot.bregman.sinkhorn_knopp(i0,i1,M,reg=reg)
+rs=np.sum(M*Gs)
+
+#%%
+
+def mylog(u):
+    tmp=np.log(u)
+    tmp[np.isnan(tmp)]=0
+    return tmp
+
+def sinkhorn_conv(a,b, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):
+
+
+    a=np.asarray(a,dtype=np.float64)
+    b=np.asarray(b,dtype=np.float64)
+
+
+    if len(b.shape)>2:
+        nbb=b.shape[2]
+        a=a[:,:,np.newaxis]
+    else:
+        nbb=0
+
+
+    if log:
+        log={'err':[]}
+
+    # we assume that no distances are null except those of the diagonal of distances
+    if nbb:
+        u = np.ones((a.shape[0],a.shape[1],nbb))/(np.prod(a.shape[:2]))
+        v = np.ones((a.shape[0],a.shape[1],nbb))/(np.prod(b.shape[:2]))
+        a0=1.0/(np.prod(b.shape[:2]))
+    else:
+        u = np.ones((a.shape[0],a.shape[1]))/(np.prod(a.shape[:2]))
+        v = np.ones((a.shape[0],a.shape[1]))/(np.prod(b.shape[:2]))
+        a0=1.0/(np.prod(b.shape[:2]))
+
+
+    k=get_1Dkernel(reg)
+
+    if nbb:
+        K=lambda I: conv2n(I,k)
+    else:
+        K=lambda I: conv2(I,k)
+
+    cpt = 0
+    err=1
+    while (err>stopThr and cpt<numItermax):
+        uprev = u
+        vprev = v
+
+        v = np.divide(b, K(u))
+        u = np.divide(a, K(v))
+
+        if (np.any(np.isnan(u)) or np.any(np.isnan(v))
+                or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+            # we have reached the machine precision
+            # come back to previous solution and quit loop
+            print('Warning: numerical errors at iteration', cpt)
+            u = uprev
+            v = vprev
+            break
+        if cpt%10==0:
+            # we can speed up the process by checking for the error only all the 10th iterations
+
+            err = np.sum((u-uprev)**2)/np.sum((u)**2)+np.sum((v-vprev)**2)/np.sum((v)**2)
+
+            if log:
+                log['err'].append(err)
+
+            if verbose:
+                if cpt%200 ==0:
+                    print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19)
+                print('{:5d}|{:8e}|'.format(cpt,err))
+        cpt = cpt +1
+    if log:
+        log['u']=u
+        log['v']=v
+
+    if nbb: #return only loss
+        res=np.zeros((nbb))
+        for i in range(nbb):
+            res[i]=np.sum(u[:,i].reshape((-1,1))*K*v[:,i].reshape((1,-1))*M)
+        if log:
+            return res,log
+        else:
+            return res
+
+    else: # return OT matrix
+        res=reg*a0*np.sum(a*mylog(u+(u==0))+b*mylog(v+(v==0)))
+        if log:
+
+            return res,log
+        else:
+            return res
+
+reg=1e0
+r,log=sinkhorn_conv(I0,I1,reg,verbose=True,log=True) +a=I0 +b=I1 +u=log['u'] +v=log['v'] +#%% barycenter interpolation diff --git a/docs/source/auto_examples/plot_OT_conv.rst b/docs/source/auto_examples/plot_OT_conv.rst new file mode 100644 index 0000000..039bbdb --- /dev/null +++ b/docs/source/auto_examples/plot_OT_conv.rst @@ -0,0 +1,241 @@ + + +.. _sphx_glr_auto_examples_plot_OT_conv.py: + + +============================== +1D Wasserstein barycenter demo +============================== + + +@author: rflamary + + + + +.. code-block:: pytb + + Traceback (most recent call last): + File "/home/rflamary/.local/lib/python2.7/site-packages/sphinx_gallery/gen_rst.py", line 518, in execute_code_block + exec(code_block, example_globals) + File "<string>", line 86, in <module> + TypeError: unsupported operand type(s) for *: 'float' and 'Mock' + + + + + +.. code-block:: python + + + import numpy as np + import matplotlib.pylab as pl + import ot + from mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used + import scipy as sp + import scipy.signal as sps + #%% parameters + + n=10 # nb bins + + # bin positions + x=np.arange(n,dtype=np.float64) + + xx,yy=np.meshgrid(x,x) + + + xpos=np.hstack((xx.reshape(-1,1),yy.reshape(-1,1))) + + M=ot.dist(xpos) + + + I0=((xx-5)**2+(yy-5)**2<3**2)*1.0 + I1=((xx-7)**2+(yy-7)**2<3**2)*1.0 + + I0/=I0.sum() + I1/=I1.sum() + + i0=I0.ravel() + i1=I1.ravel() + + M=M[i0>0,:][:,i1>0].copy() + i0=i0[i0>0] + i1=i1[i1>0] + Itot=np.concatenate((I0[:,:,np.newaxis],I1[:,:,np.newaxis]),2) + + + #%% plot the distributions + + pl.figure(1) + pl.subplot(2,2,1) + pl.imshow(I0) + pl.subplot(2,2,2) + pl.imshow(I1) + + + #%% barycenter computation + + alpha=0.5 # 0<=alpha<=1 + weights=np.array([1-alpha,alpha]) + + + def conv2(I,k): + return sp.ndimage.convolve1d(sp.ndimage.convolve1d(I,k,axis=1),k,axis=0) + + def conv2n(I,k): + res=np.zeros_like(I) + for i in range(I.shape[2]): + res[:,:,i]=conv2(I[:,:,i],k) + return res + + + def 
get_1Dkernel(reg,thr=1e-16,wmax=1024): + w=max(min(wmax,2*int((-np.log(thr)*reg)**(.5))),3) + x=np.arange(w,dtype=np.float64) + return np.exp(-((x-w/2)**2)/reg) + + thr=1e-16 + reg=1e0 + + k=get_1Dkernel(reg) + pl.figure(2) + pl.plot(k) + + I05=conv2(I0,k) + + pl.figure(1) + pl.subplot(2,2,1) + pl.imshow(I0) + pl.subplot(2,2,2) + pl.imshow(I05) + + #%% + + G=ot.emd(i0,i1,M) + r0=np.sum(M*G) + + reg=1e-1 + Gs=ot.bregman.sinkhorn_knopp(i0,i1,M,reg=reg) + rs=np.sum(M*Gs) + + #%% + + def mylog(u): + tmp=np.log(u) + tmp[np.isnan(tmp)]=0 + return tmp + + def sinkhorn_conv(a,b, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs): + + + a=np.asarray(a,dtype=np.float64) + b=np.asarray(b,dtype=np.float64) + + + if len(b.shape)>2: + nbb=b.shape[2] + a=a[:,:,np.newaxis] + else: + nbb=0 + + + if log: + log={'err':[]} + + # we assume that no distances are null except those of the diagonal of distances + if nbb: + u = np.ones((a.shape[0],a.shape[1],nbb))/(np.prod(a.shape[:2])) + v = np.ones((a.shape[0],a.shape[1],nbb))/(np.prod(b.shape[:2])) + a0=1.0/(np.prod(b.shape[:2])) + else: + u = np.ones((a.shape[0],a.shape[1]))/(np.prod(a.shape[:2])) + v = np.ones((a.shape[0],a.shape[1]))/(np.prod(b.shape[:2])) + a0=1.0/(np.prod(b.shape[:2])) + + + k=get_1Dkernel(reg) + + if nbb: + K=lambda I: conv2n(I,k) + else: + K=lambda I: conv2(I,k) + + cpt = 0 + err=1 + while (err>stopThr and cpt<numItermax): + uprev = u + vprev = v + + v = np.divide(b, K(u)) + u = np.divide(a, K(v)) + + if (np.any(np.isnan(u)) or np.any(np.isnan(v)) + or np.any(np.isinf(u)) or np.any(np.isinf(v))): + # we have reached the machine precision + # come back to previous solution and quit loop + print('Warning: numerical errors at iteration', cpt) + u = uprev + v = vprev + break + if cpt%10==0: + # we can speed up the process by checking for the error only all the 10th iterations + + err = np.sum((u-uprev)**2)/np.sum((u)**2)+np.sum((v-vprev)**2)/np.sum((v)**2) + + if log: + log['err'].append(err) + + 
if verbose: + if cpt%200 ==0: + print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19) + print('{:5d}|{:8e}|'.format(cpt,err)) + cpt = cpt +1 + if log: + log['u']=u + log['v']=v + + if nbb: #return only loss + res=np.zeros((nbb)) + for i in range(nbb): + res[i]=np.sum(u[:,i].reshape((-1,1))*K*v[:,i].reshape((1,-1))*M) + if log: + return res,log + else: + return res + + else: # return OT matrix + res=reg*a0*np.sum(a*mylog(u+(u==0))+b*mylog(v+(v==0))) + if log: + + return res,log + else: + return res + + reg=1e0 + r,log=sinkhorn_conv(I0,I1,reg,verbose=True,log=True) + a=I0 + b=I1 + u=log['u'] + v=log['v'] + #%% barycenter interpolation + +**Total running time of the script:** ( 0 minutes 0.000 seconds) + + + +.. container:: sphx-glr-footer + + + .. container:: sphx-glr-download + + :download:`Download Python source code: plot_OT_conv.py <plot_OT_conv.py>` + + + + .. container:: sphx-glr-download + + :download:`Download Jupyter notebook: plot_OT_conv.ipynb <plot_OT_conv.ipynb>` + +.. rst-class:: sphx-glr-signature + + `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_ diff --git a/docs/source/auto_examples/plot_WDA.ipynb b/docs/source/auto_examples/plot_WDA.ipynb index 6d641a7..408a605 100644 --- a/docs/source/auto_examples/plot_WDA.ipynb +++ b/docs/source/auto_examples/plot_WDA.ipynb @@ -15,7 +15,7 @@ }, { "source": [ - "\n# WAsserstein Discriminant Analysis\n\n\n@author: rflamary\n\n" + "\n# Wasserstein Discriminant Analysis\n\n\n@author: rflamary\n\n" ], "cell_type": "markdown", "metadata": {} diff --git a/docs/source/auto_examples/plot_WDA.py b/docs/source/auto_examples/plot_WDA.py index 94b7ef4..bbe3888 100644 --- a/docs/source/auto_examples/plot_WDA.py +++ b/docs/source/auto_examples/plot_WDA.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ ================================= -WAsserstein Discriminant Analysis +Wasserstein Discriminant Analysis ================================= @author: rflamary diff --git a/docs/source/auto_examples/plot_WDA.rst 
b/docs/source/auto_examples/plot_WDA.rst index 379a133..540555d 100644 --- a/docs/source/auto_examples/plot_WDA.rst +++ b/docs/source/auto_examples/plot_WDA.rst @@ -4,7 +4,7 @@ ================================= -WAsserstein Discriminant Analysis +Wasserstein Discriminant Analysis ================================= @author: rflamary @@ -23,23 +23,26 @@ WAsserstein Discriminant Analysis Compiling cost function... Computing gradient of cost function... iter cost val grad. norm - 1 +7.5272200933021116e-01 8.85804426e-01 - 2 +2.5764223980223788e-01 3.04501586e-01 - 3 +1.6018169776696620e-01 1.78298483e-01 - 4 +1.4560944642106255e-01 1.42133298e-01 - 5 +1.0243843483991794e-01 1.23342675e-01 - 6 +7.8856617504010643e-02 1.05379766e-01 - 7 +7.7620851864404483e-02 1.04044062e-01 - 8 +7.3160520861018416e-02 8.33770034e-02 - 9 +6.6999294576662857e-02 2.87368977e-02 - 10 +6.6250206928793964e-02 1.72155066e-03 - 11 +6.6247631521353170e-02 2.43806911e-04 - 12 +6.6247596955965438e-02 1.40066459e-04 - 13 +6.6247580176638649e-02 4.77471577e-06 - 14 +6.6247580163923028e-02 3.00484279e-06 - 15 +6.6247580159235792e-02 1.91039983e-06 - 16 +6.6247580156889613e-02 9.56038747e-07 - Terminated - min grad norm reached after 16 iterations, 7.78 seconds. 
+ 1 +5.2427396265941129e-01 8.16627951e-01 + 2 +1.7904850059627236e-01 1.91366819e-01 + 3 +1.6985797253002377e-01 1.70940682e-01 + 4 +1.3903474972292729e-01 1.28606342e-01 + 5 +7.4961734618782416e-02 6.41973980e-02 + 6 +7.1900245222486239e-02 4.25693592e-02 + 7 +7.0472023318269614e-02 2.34599232e-02 + 8 +6.9917568641317152e-02 5.66542766e-03 + 9 +6.9885086242452696e-02 4.05756115e-04 + 10 +6.9884967432653489e-02 2.16836017e-04 + 11 +6.9884923649884148e-02 5.74961622e-05 + 12 +6.9884921818258436e-02 3.83257203e-05 + 13 +6.9884920459612282e-02 9.97486224e-06 + 14 +6.9884920414414409e-02 7.33567875e-06 + 15 +6.9884920388431387e-02 5.23889187e-06 + 16 +6.9884920385183902e-02 4.91959084e-06 + 17 +6.9884920373983223e-02 3.56451669e-06 + 18 +6.9884920369701245e-02 2.88858709e-06 + 19 +6.9884920361621208e-02 1.82294279e-07 + Terminated - min grad norm reached after 19 iterations, 9.65 seconds. @@ -105,7 +108,7 @@ WAsserstein Discriminant Analysis pl.legend(loc=0) pl.title('Projected test samples') -**Total running time of the script:** ( 0 minutes 14.134 seconds) +**Total running time of the script:** ( 0 minutes 16.902 seconds) diff --git a/docs/source/auto_examples/plot_compute_emd.ipynb b/docs/source/auto_examples/plot_compute_emd.ipynb new file mode 100644 index 0000000..4162144 --- /dev/null +++ b/docs/source/auto_examples/plot_compute_emd.ipynb @@ -0,0 +1,54 @@ +{ + "nbformat_minor": 0, + "nbformat": 4, + "cells": [ + { + "execution_count": null, + "cell_type": "code", + "source": [ + "%matplotlib inline" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + }, + { + "source": [ + "\n# 1D optimal transport\n\n\n@author: rflamary\n\n" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "execution_count": null, + "cell_type": "code", + "source": [ + "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\nfrom ot.datasets import get_1D_gauss as gauss\n\n\n#%% parameters\n\nn=100 # nb bins\nn_target=50 # nb target distributions\n\n\n# bin 
positions\nx=np.arange(n,dtype=np.float64)\n\nlst_m=np.linspace(20,90,n_target)\n\n# Gaussian distributions\na=gauss(n,m=20,s=5) # m= mean, s= std\n\nB=np.zeros((n,n_target))\n\nfor i,m in enumerate(lst_m):\n B[:,i]=gauss(n,m=m,s=5)\n\n# loss matrix and normalization\nM=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'euclidean')\nM/=M.max()\nM2=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'sqeuclidean')\nM2/=M2.max()\n#%% plot the distributions\n\npl.figure(1)\npl.subplot(2,1,1)\npl.plot(x,a,'b',label='Source distribution')\npl.title('Source distribution')\npl.subplot(2,1,2)\npl.plot(x,B,label='Target distributions')\npl.title('Target distributions')\n\n#%% Compute and plot distributions and loss matrix\n\nd_emd=ot.emd2(a,B,M) # direct computation of EMD\nd_emd2=ot.emd2(a,B,M2) # direct computation of EMD with loss M3\n\n\npl.figure(2)\npl.plot(d_emd,label='Euclidean EMD')\npl.plot(d_emd2,label='Squared Euclidean EMD')\npl.title('EMD distances')\npl.legend()\n\n#%%\nreg=1e-2\nd_sinkhorn=ot.sinkhorn(a,B,M,reg)\nd_sinkhorn2=ot.sinkhorn(a,B,M2,reg)\n\npl.figure(2)\npl.clf()\npl.plot(d_emd,label='Euclidean EMD')\npl.plot(d_emd2,label='Squared Euclidean EMD')\npl.plot(d_sinkhorn,'+',label='Euclidean Sinkhorn')\npl.plot(d_sinkhorn2,'+',label='Squared Euclidean Sinkhorn')\npl.title('EMD distances')\npl.legend()" + ], + "outputs": [], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "name": "python2", + "language": "python" + }, + "language_info": { + "mimetype": "text/x-python", + "nbconvert_exporter": "python", + "name": "python", + "file_extension": ".py", + "version": "2.7.12", + "pygments_lexer": "ipython2", + "codemirror_mode": { + "version": 2, + "name": "ipython" + } + } + } +}
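The `plot_compute_emd` example above compares `ot.emd2` values under two costs (note that both loss matrices are normalized by their max, which rescales the returned distances). For the unnormalized |x - y| cost, 1-D EMD has a closed form through cumulative histograms that makes a convenient cross-check; a sketch, with a local `gauss` helper standing in for `ot.datasets.get_1D_gauss`:

```python
import numpy as np

def gauss(n, m, s):
    # hypothetical stand-in for ot.datasets.get_1D_gauss: normalized
    # Gaussian histogram with mean m and std s on bins 0..n-1
    g = np.exp(-(np.arange(n) - m) ** 2 / (2.0 * s ** 2))
    return g / g.sum()

def w1_1d(a, b, dx=1.0):
    # 1-D EMD with |x - y| cost: L1 distance between the two CDFs
    return np.sum(np.abs(np.cumsum(a) - np.cumsum(b))) * dx

n = 100
a = gauss(n, m=20, s=5)
b = gauss(n, m=60, s=5)
print(w1_1d(a, b))  # ~ 40.0, the shift between the two means
```

For two histograms of the same shape this recovers the distance between their means, which is a useful sanity check on the monotone trend seen in the EMD plots of the example.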
\ No newline at end of file diff --git a/docs/source/auto_examples/plot_compute_emd.py b/docs/source/auto_examples/plot_compute_emd.py new file mode 100644 index 0000000..c7063e8 --- /dev/null +++ b/docs/source/auto_examples/plot_compute_emd.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +""" +==================== +1D optimal transport +==================== + +@author: rflamary +""" + +import numpy as np +import matplotlib.pylab as pl +import ot +from ot.datasets import get_1D_gauss as gauss + + +#%% parameters + +n=100 # nb bins +n_target=50 # nb target distributions + + +# bin positions +x=np.arange(n,dtype=np.float64) + +lst_m=np.linspace(20,90,n_target) + +# Gaussian distributions +a=gauss(n,m=20,s=5) # m= mean, s= std + +B=np.zeros((n,n_target)) + +for i,m in enumerate(lst_m): + B[:,i]=gauss(n,m=m,s=5) + +# loss matrix and normalization +M=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'euclidean') +M/=M.max() +M2=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'sqeuclidean') +M2/=M2.max() +#%% plot the distributions + +pl.figure(1) +pl.subplot(2,1,1) +pl.plot(x,a,'b',label='Source distribution') +pl.title('Source distribution') +pl.subplot(2,1,2) +pl.plot(x,B,label='Target distributions') +pl.title('Target distributions') + +#%% Compute and plot distributions and loss matrix + +d_emd=ot.emd2(a,B,M) # direct computation of EMD +d_emd2=ot.emd2(a,B,M2) # direct computation of EMD with loss M2 + + +pl.figure(2) +pl.plot(d_emd,label='Euclidean EMD') +pl.plot(d_emd2,label='Squared Euclidean EMD') +pl.title('EMD distances') +pl.legend() + +#%% +reg=1e-2 +d_sinkhorn=ot.sinkhorn(a,B,M,reg) +d_sinkhorn2=ot.sinkhorn(a,B,M2,reg) + +pl.figure(2) +pl.clf() +pl.plot(d_emd,label='Euclidean EMD') +pl.plot(d_emd2,label='Squared Euclidean EMD') +pl.plot(d_sinkhorn,'+',label='Euclidean Sinkhorn') +pl.plot(d_sinkhorn2,'+',label='Squared Euclidean Sinkhorn') +pl.title('EMD distances') +pl.legend()
\ No newline at end of file diff --git a/docs/source/auto_examples/plot_compute_emd.rst b/docs/source/auto_examples/plot_compute_emd.rst new file mode 100644 index 0000000..4c7445b --- /dev/null +++ b/docs/source/auto_examples/plot_compute_emd.rst @@ -0,0 +1,119 @@ + + +.. _sphx_glr_auto_examples_plot_compute_emd.py: + + +==================== +1D optimal transport +==================== + +@author: rflamary + + + + +.. rst-class:: sphx-glr-horizontal + + + * + + .. image:: /auto_examples/images/sphx_glr_plot_compute_emd_001.png + :scale: 47 + + * + + .. image:: /auto_examples/images/sphx_glr_plot_compute_emd_002.png + :scale: 47 + + + + + +.. code-block:: python + + + import numpy as np + import matplotlib.pylab as pl + import ot + from ot.datasets import get_1D_gauss as gauss + + + #%% parameters + + n=100 # nb bins + n_target=50 # nb target distributions + + + # bin positions + x=np.arange(n,dtype=np.float64) + + lst_m=np.linspace(20,90,n_target) + + # Gaussian distributions + a=gauss(n,m=20,s=5) # m= mean, s= std + + B=np.zeros((n,n_target)) + + for i,m in enumerate(lst_m): + B[:,i]=gauss(n,m=m,s=5) + + # loss matrix and normalization + M=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'euclidean') + M/=M.max() + M2=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'sqeuclidean') + M2/=M2.max() + #%% plot the distributions + + pl.figure(1) + pl.subplot(2,1,1) + pl.plot(x,a,'b',label='Source distribution') + pl.title('Source distribution') + pl.subplot(2,1,2) + pl.plot(x,B,label='Target distributions') + pl.title('Target distributions') + + #%% Compute and plot distributions and loss matrix + + d_emd=ot.emd2(a,B,M) # direct computation of EMD + d_emd2=ot.emd2(a,B,M2) # direct computation of EMD with loss M3 + + + pl.figure(2) + pl.plot(d_emd,label='Euclidean EMD') + pl.plot(d_emd2,label='Squared Euclidean EMD') + pl.title('EMD distances') + pl.legend() + + #%% + reg=1e-2 + d_sinkhorn=ot.sinkhorn(a,B,M,reg) + d_sinkhorn2=ot.sinkhorn(a,B,M2,reg) + + pl.figure(2) + pl.clf() + 
pl.plot(d_emd,label='Euclidean EMD') + pl.plot(d_emd2,label='Squared Euclidean EMD') + pl.plot(d_sinkhorn,'+',label='Euclidean Sinkhorn') + pl.plot(d_sinkhorn2,'+',label='Squared Euclidean Sinkhorn') + pl.title('EMD distances') + pl.legend() +**Total running time of the script:** ( 0 minutes 0.521 seconds) + + + +.. container:: sphx-glr-footer + + + .. container:: sphx-glr-download + + :download:`Download Python source code: plot_compute_emd.py <plot_compute_emd.py>` + + + + .. container:: sphx-glr-download + + :download:`Download Jupyter notebook: plot_compute_emd.ipynb <plot_compute_emd.ipynb>` + +.. rst-class:: sphx-glr-signature + + `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_ diff --git a/docs/source/auto_examples/plot_optim_OTreg.ipynb b/docs/source/auto_examples/plot_optim_OTreg.ipynb index 00d4702..5ded922 100644 --- a/docs/source/auto_examples/plot_optim_OTreg.ipynb +++ b/docs/source/auto_examples/plot_optim_OTreg.ipynb @@ -24,7 +24,7 @@ "execution_count": null, "cell_type": "code", "source": [ - "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\n\n\n\n#%% parameters\n\nn=100 # nb bins\n\n# bin positions\nx=np.arange(n,dtype=np.float64)\n\n# Gaussian distributions\na=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std\nb=ot.datasets.get_1D_gauss(n,m=60,s=10)\n\n# loss matrix\nM=ot.dist(x.reshape((n,1)),x.reshape((n,1)))\nM/=M.max()\n\n#%% EMD\n\nG0=ot.emd(a,b,M)\n\npl.figure(3)\not.plot.plot1D_mat(a,b,G0,'OT matrix G0')\n\n#%% Example with Frobenius norm regularization\n\ndef f(G): return 0.5*np.sum(G**2)\ndef df(G): return G\n\nreg=1e-1\n\nGl2=ot.optim.cg(a,b,M,reg,f,df,verbose=True)\n\npl.figure(3)\not.plot.plot1D_mat(a,b,Gl2,'OT matrix Frob. reg')\n\n#%% Example with entropic regularization\n\ndef f(G): return np.sum(G*np.log(G))\ndef df(G): return np.log(G)+1\n\nreg=1e-3\n\nGe=ot.optim.cg(a,b,M,reg,f,df,verbose=True)\n\npl.figure(4)\not.plot.plot1D_mat(a,b,Ge,'OT matrix Entrop. 
reg')\n\n#%% Example with Frobenius norm + entropic regularization with gcg\n\ndef f(G): return 0.5*np.sum(G**2)\ndef df(G): return G\n\nreg1=1e-3\nreg2=1e-1\n\nGel2=ot.optim.gcg(a,b,M,reg1,reg2,f,df,verbose=True)\n\npl.figure(5)\not.plot.plot1D_mat(a,b,Gel2,'OT entropic + matrix Frob. reg')" + "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\n\n\n\n#%% parameters\n\nn=100 # nb bins\n\n# bin positions\nx=np.arange(n,dtype=np.float64)\n\n# Gaussian distributions\na=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std\nb=ot.datasets.get_1D_gauss(n,m=60,s=10)\n\n# loss matrix\nM=ot.dist(x.reshape((n,1)),x.reshape((n,1)))\nM/=M.max()\n\n#%% EMD\n\nG0=ot.emd(a,b,M)\n\npl.figure(3)\not.plot.plot1D_mat(a,b,G0,'OT matrix G0')\n\n#%% Example with Frobenius norm regularization\n\ndef f(G): return 0.5*np.sum(G**2)\ndef df(G): return G\n\nreg=1e-1\n\nGl2=ot.optim.cg(a,b,M,reg,f,df,verbose=True)\n\npl.figure(3)\not.plot.plot1D_mat(a,b,Gl2,'OT matrix Frob. reg')\n\n#%% Example with entropic regularization\n\ndef f(G): return np.sum(G*np.log(G))\ndef df(G): return np.log(G)+1\n\nreg=1e-3\n\nGe=ot.optim.cg(a,b,M,reg,f,df,verbose=True)\n\npl.figure(4)\not.plot.plot1D_mat(a,b,Ge,'OT matrix Entrop. reg')\n\n#%% Example with Frobenius norm + entropic regularization with gcg\n\ndef f(G): return 0.5*np.sum(G**2)\ndef df(G): return G\n\nreg1=1e-3\nreg2=1e-1\n\nGel2=ot.optim.gcg(a,b,M,reg1,reg2,f,df,verbose=True)\n\npl.figure(5)\not.plot.plot1D_mat(a,b,Gel2,'OT entropic + matrix Frob. reg')\npl.show()" ], "outputs": [], "metadata": { diff --git a/docs/source/auto_examples/plot_optim_OTreg.py b/docs/source/auto_examples/plot_optim_OTreg.py index c585445..8abb426 100644 --- a/docs/source/auto_examples/plot_optim_OTreg.py +++ b/docs/source/auto_examples/plot_optim_OTreg.py @@ -70,4 +70,5 @@ reg2=1e-1 Gel2=ot.optim.gcg(a,b,M,reg1,reg2,f,df,verbose=True) pl.figure(5) -ot.plot.plot1D_mat(a,b,Gel2,'OT entropic + matrix Frob. reg')
\ No newline at end of file +ot.plot.plot1D_mat(a,b,Gel2,'OT entropic + matrix Frob. reg') +pl.show()
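The `plot_optim_OTreg` example above solves quadratically and entropically regularized OT with POT's generic conditional-gradient solver `ot.optim.cg(a, b, M, reg, f, df)`. A dependency-free sketch of that Frank-Wolfe structure on a hypothetical tiny problem, with `scipy.optimize.linprog` standing in for the exact `ot.emd` linear oracle and a fixed 2/(k+2) step instead of POT's line search:

```python
import numpy as np
from scipy.optimize import linprog

# conditional-gradient loop for min <G, M> + reg * 0.5 * ||G||^2
# over the transport polytope {G >= 0, G 1 = a, G^T 1 = b}
n = 5
rng = np.random.default_rng(0)
M = rng.random((n, n))
a = np.full(n, 1.0 / n)
b = np.full(n, 1.0 / n)

# equality constraints encoding row sums = a and column sums = b
A_eq = np.zeros((2 * n, n * n))
for i in range(n):
    A_eq[i, i * n:(i + 1) * n] = 1          # row i of G
    A_eq[n + i, i::n] = 1                   # column i of G
b_eq = np.concatenate([a, b])

reg = 0.1
G = np.outer(a, b)                          # feasible starting plan
for it in range(50):
    grad = M + reg * G                      # gradient of the objective
    res = linprog(grad.ravel(), A_eq=A_eq, b_eq=b_eq, bounds=(0, None))
    S = res.x.reshape(n, n)                 # linear minimizer (a vertex)
    gamma = 2.0 / (it + 2.0)                # standard Frank-Wolfe step
    G = (1 - gamma) * G + gamma * S

print(np.allclose(G.sum(axis=1), a), np.allclose(G.sum(axis=0), b))
```

Because each vertex `S` is feasible and the update is a convex combination, the iterate stays in the transport polytope throughout, which is the property `ot.optim.cg` relies on as well.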
\ No newline at end of file diff --git a/docs/source/auto_examples/plot_optim_OTreg.rst b/docs/source/auto_examples/plot_optim_OTreg.rst index 0dff327..70cd26c 100644 --- a/docs/source/auto_examples/plot_optim_OTreg.rst +++ b/docs/source/auto_examples/plot_optim_OTreg.rst @@ -562,7 +562,8 @@ Regularized OT with generic solver pl.figure(5) ot.plot.plot1D_mat(a,b,Gel2,'OT entropic + matrix Frob. reg') -**Total running time of the script:** ( 0 minutes 2.422 seconds) + pl.show() +**Total running time of the script:** ( 0 minutes 2.319 seconds) diff --git a/docs/source/readme.rst b/docs/source/readme.rst index 13cf572..625cebf 100644 --- a/docs/source/readme.rst +++ b/docs/source/readme.rst @@ -11,7 +11,8 @@ It provides the following solvers: - OT solver for the linear program/ Earth Movers Distance [1]. - Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] - and stabilized version [9][10]. + and stabilized version [9][10] with optional GPU implementation + (requires cudamat). - Bregman projections for Wasserstein barycenter [3] and unmixing [4]. - Optimal transport for domain adaptation with group lasso regularization [5] @@ -71,14 +72,14 @@ Dependencies Some sub-modules require additional dependences which are discussed below -- ot.dr (Wasserstein dimensionality rediuction) depends on autograd and - pymanopt that can be installed with: +- **ot.dr** (Wasserstein dimensionality reduction) depends on autograd + and pymanopt that can be installed with: :: pip install pymanopt autograd -- ot.gpu (GPU accelerated OT) depends on cudamat that have to be +- **ot.gpu** (GPU accelerated OT) depends on cudamat that has to be installed with: :: @@ -87,6 +88,8 @@ below cd cudamat python setup.py install --user # for user install (no root) +Obviously you need CUDA installed and a compatible GPU. + Examples -------- @@ -144,47 +147,59 @@ References ---------- [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, -December). 
Displacement interpolation using Lagrangian mass transport. +December). `Displacement interpolation using Lagrangian mass +transport <https://people.csail.mit.edu/sparis/publi/2011/sigasia/Bonneel_11_Displacement_Interpolation.pdf>`__. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM. -[2] Cuturi, M. (2013). Sinkhorn distances: Lightspeed computation of -optimal transport. In Advances in Neural Information Processing Systems -(pp. 2292-2300). +[2] Cuturi, M. (2013). `Sinkhorn distances: Lightspeed computation of +optimal transport <https://arxiv.org/pdf/1306.0895.pdf>`__. In Advances +in Neural Information Processing Systems (pp. 2292-2300). [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. -(2015). Iterative Bregman projections for regularized transportation -problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. +(2015). `Iterative Bregman projections for regularized transportation +problems <https://arxiv.org/pdf/1412.5154.pdf>`__. SIAM Journal on +Scientific Computing, 37(2), A1111-A1138. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, -Supervised planetary unmixing with optimal transport, Whorkshop on -Hyperspectral Image and Signal Processing : Evolution in Remote Sensing -(WHISPERS), 2016. +`Supervised planetary unmixing with optimal +transport <https://hal.archives-ouvertes.fr/hal-01377236/document>`__, +Workshop on Hyperspectral Image and Signal Processing : Evolution in +Remote Sensing (WHISPERS), 2016. -[5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport -for Domain Adaptation," in IEEE Transactions on Pattern Analysis and -Machine Intelligence , vol.PP, no.99, pp.1-1 +[5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, `Optimal Transport +for Domain Adaptation <https://arxiv.org/pdf/1507.00504.pdf>`__, in IEEE +Transactions on Pattern Analysis and Machine Intelligence , vol.PP, +no.99, pp.1-1 [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). 
-Regularized discrete optimal transport. SIAM Journal on Imaging -Sciences, 7(3), 1853-1882. +`Regularized discrete optimal +transport <https://arxiv.org/pdf/1307.5551.pdf>`__. SIAM Journal on +Imaging Sciences, 7(3), 1853-1882. -[7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized -conditional gradient: analysis of convergence and applications. arXiv -preprint arXiv:1510.06567. +[7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). `Generalized +conditional gradient: analysis of convergence and +applications <https://arxiv.org/pdf/1510.06567.pdf>`__. arXiv preprint +arXiv:1510.06567. -[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation -for discrete optimal transport", Neural Information Processing Systems -(NIPS), 2016. +[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, `Mapping estimation +for discrete optimal +transport <http://remi.flamary.com/biblio/perrot2016mapping.pdf>`__, +Neural Information Processing Systems (NIPS), 2016. -[9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for -Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. +[9] Schmitzer, B. (2016). `Stabilized Sparse Scaling Algorithms for +Entropy Regularized Transport +Problems <https://arxiv.org/pdf/1610.06519.pdf>`__. arXiv preprint +arXiv:1610.06519. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). -Scaling algorithms for unbalanced transport problems. arXiv preprint +`Scaling algorithms for unbalanced transport +problems <https://arxiv.org/pdf/1607.05816.pdf>`__. arXiv preprint arXiv:1607.05816. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). -Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063. +`Wasserstein Discriminant +Analysis <https://arxiv.org/pdf/1608.08063.pdf>`__. arXiv preprint +arXiv:1608.08063. .. |PyPI version| image:: https://badge.fury.io/py/POT.svg :target: https://badge.fury.io/py/POT |
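The readme's reference [2] (Cuturi, 2013) is the Sinkhorn-Knopp algorithm behind `ot.sinkhorn` and `ot.bregman.sinkhorn_knopp`. As a hedged, dependency-free illustration of the underlying iteration only (not the POT implementation, which adds stopping criteria and the log-domain stabilization of [9][10]):

```python
import numpy as np

def sinkhorn(a, b, M, reg, n_iter=1000):
    """Entropic OT plan via plain Sinkhorn-Knopp scaling iterations."""
    K = np.exp(-M / reg)          # Gibbs kernel
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)         # rescale to match the column marginal b
        u = a / (K @ v)           # rescale to match the row marginal a
    return u[:, None] * K * v[None, :]

rng = np.random.default_rng(0)
n = 50
x = rng.random((n, 1))
M = (x - x.T) ** 2                # squared-distance cost matrix
M /= M.max()
a = np.full(n, 1.0 / n)           # uniform source weights
b = np.full(n, 1.0 / n)           # uniform target weights

G = sinkhorn(a, b, M, reg=0.1)
print(np.allclose(G.sum(axis=1), a), np.allclose(G.sum(axis=0), b))
```

For small regularization this naive version under/overflows in `exp(-M / reg)`, which is exactly what the stabilized solvers of [9][10] cited above are for.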