diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2017-08-30 17:01:01 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2017-08-30 17:01:01 +0200 |
commit | dc8737a30cb6d9f1305173eb8d16fe6716fd1231 (patch) | |
tree | 1f03384de2af88ed07a1e850e0871db826ed53e7 /docs/source/auto_examples/plot_compute_emd.ipynb | |
parent | c2a7a1f3ab4ba5c4f5adeca0fa22d8d6b4fc079d (diff) |
wroking make!
Diffstat (limited to 'docs/source/auto_examples/plot_compute_emd.ipynb')
-rw-r--r-- | docs/source/auto_examples/plot_compute_emd.ipynb | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/docs/source/auto_examples/plot_compute_emd.ipynb b/docs/source/auto_examples/plot_compute_emd.ipynb index 4162144..ce3f8c6 100644 --- a/docs/source/auto_examples/plot_compute_emd.ipynb +++ b/docs/source/auto_examples/plot_compute_emd.ipynb @@ -15,7 +15,7 @@ }, { "source": [ - "\n# 1D optimal transport\n\n\n@author: rflamary\n\n" + "\n# 1D optimal transport\n\n\n\n" ], "cell_type": "markdown", "metadata": {} @@ -24,7 +24,7 @@ "execution_count": null, "cell_type": "code", "source": [ - "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\nfrom ot.datasets import get_1D_gauss as gauss\n\n\n#%% parameters\n\nn=100 # nb bins\nn_target=50 # nb target distributions\n\n\n# bin positions\nx=np.arange(n,dtype=np.float64)\n\nlst_m=np.linspace(20,90,n_target)\n\n# Gaussian distributions\na=gauss(n,m=20,s=5) # m= mean, s= std\n\nB=np.zeros((n,n_target))\n\nfor i,m in enumerate(lst_m):\n B[:,i]=gauss(n,m=m,s=5)\n\n# loss matrix and normalization\nM=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'euclidean')\nM/=M.max()\nM2=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'sqeuclidean')\nM2/=M2.max()\n#%% plot the distributions\n\npl.figure(1)\npl.subplot(2,1,1)\npl.plot(x,a,'b',label='Source distribution')\npl.title('Source distribution')\npl.subplot(2,1,2)\npl.plot(x,B,label='Target distributions')\npl.title('Target distributions')\n\n#%% Compute and plot distributions and loss matrix\n\nd_emd=ot.emd2(a,B,M) # direct computation of EMD\nd_emd2=ot.emd2(a,B,M2) # direct computation of EMD with loss M3\n\n\npl.figure(2)\npl.plot(d_emd,label='Euclidean EMD')\npl.plot(d_emd2,label='Squared Euclidean EMD')\npl.title('EMD distances')\npl.legend()\n\n#%%\nreg=1e-2\nd_sinkhorn=ot.sinkhorn(a,B,M,reg)\nd_sinkhorn2=ot.sinkhorn(a,B,M2,reg)\n\npl.figure(2)\npl.clf()\npl.plot(d_emd,label='Euclidean EMD')\npl.plot(d_emd2,label='Squared Euclidean EMD')\npl.plot(d_sinkhorn,'+',label='Euclidean Sinkhorn')\npl.plot(d_sinkhorn2,'+',label='Squared Euclidean Sinkhorn')\npl.title('EMD distances')\npl.legend()" + "# Author: Remi Flamary <remi.flamary@unice.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\nfrom ot.datasets import get_1D_gauss as gauss\n\n\n#%% parameters\n\nn = 100 # nb bins\nn_target = 50 # nb target distributions\n\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\nlst_m = np.linspace(20, 90, n_target)\n\n# Gaussian distributions\na = gauss(n, m=20, s=5) # m= mean, s= std\n\nB = np.zeros((n, n_target))\n\nfor i, m in enumerate(lst_m):\n B[:, i] = gauss(n, m=m, s=5)\n\n# loss matrix and normalization\nM = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean')\nM /= M.max()\nM2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean')\nM2 /= M2.max()\n#%% plot the distributions\n\npl.figure(1)\npl.subplot(2, 1, 1)\npl.plot(x, a, 'b', label='Source distribution')\npl.title('Source distribution')\npl.subplot(2, 1, 2)\npl.plot(x, B, label='Target distributions')\npl.title('Target distributions')\npl.tight_layout()\n\n#%% Compute and plot distributions and loss matrix\n\nd_emd = ot.emd2(a, B, M) # direct computation of EMD\nd_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M3\n\n\npl.figure(2)\npl.plot(d_emd, label='Euclidean EMD')\npl.plot(d_emd2, label='Squared Euclidean EMD')\npl.title('EMD distances')\npl.legend()\n\n#%%\nreg = 1e-2\nd_sinkhorn = ot.sinkhorn2(a, B, M, reg)\nd_sinkhorn2 = ot.sinkhorn2(a, B, M2, reg)\n\npl.figure(2)\npl.clf()\npl.plot(d_emd, label='Euclidean EMD')\npl.plot(d_emd2, label='Squared Euclidean EMD')\npl.plot(d_sinkhorn, '+', label='Euclidean Sinkhorn')\npl.plot(d_sinkhorn2, '+', label='Squared Euclidean Sinkhorn')\npl.title('EMD distances')\npl.legend()\n\npl.show()" ], "outputs": [], "metadata": { |