diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2017-07-03 16:33:00 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2017-07-03 16:33:00 +0200 |
commit | adbf95ee9a720fa38b5b91d0a9d5c3b22ba0b226 (patch) | |
tree | 834a6934ab09c4daec751c220ba847e8be61bb4b /docs/source/auto_examples/plot_OT_2D_samples.ipynb | |
parent | f639518e9b96c5904122e62e024ed4ae369ceb33 (diff) |
update doc
Diffstat (limited to 'docs/source/auto_examples/plot_OT_2D_samples.ipynb')
-rw-r--r-- | docs/source/auto_examples/plot_OT_2D_samples.ipynb | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/docs/source/auto_examples/plot_OT_2D_samples.ipynb b/docs/source/auto_examples/plot_OT_2D_samples.ipynb index 7d42ba7..fad0467 100644 --- a/docs/source/auto_examples/plot_OT_2D_samples.ipynb +++ b/docs/source/auto_examples/plot_OT_2D_samples.ipynb @@ -24,7 +24,7 @@ "execution_count": null, "cell_type": "code", "source": [ - "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\n\n#%% parameters and data generation\n\nn=2 # nb samples\n\nmu_s=np.array([0,0])\ncov_s=np.array([[1,0],[0,1]])\n\nmu_t=np.array([4,4])\ncov_t=np.array([[1,-.8],[-.8,1]])\n\nxs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s)\nxt=ot.datasets.get_2D_samples_gauss(n,mu_t,cov_t)\n\na,b = ot.unif(n),ot.unif(n) # uniform distribution on samples\n\n# loss matrix\nM=ot.dist(xs,xt)\nM/=M.max()\n\n#%% plot samples\n\npl.figure(1)\npl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\npl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\npl.legend(loc=0)\npl.title('Source and traget distributions')\n\npl.figure(2)\npl.imshow(M,interpolation='nearest')\npl.title('Cost matrix M')\n\n\n#%% EMD\n\nG0=ot.emd(a,b,M)\n\npl.figure(3)\npl.imshow(G0,interpolation='nearest')\npl.title('OT matrix G0')\n\npl.figure(4)\not.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])\npl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\npl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\npl.legend(loc=0)\npl.title('OT matrix with samples')\n\n\n#%% sinkhorn\n\n# reg term\nlambd=5e-3\n\nGs=ot.sinkhorn(a,b,M,lambd)\n\npl.figure(5)\npl.imshow(Gs,interpolation='nearest')\npl.title('OT matrix sinkhorn')\n\npl.figure(6)\not.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1])\npl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\npl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\npl.legend(loc=0)\npl.title('OT matrix Sinkhorn with samples')" + "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\n\n#%% parameters and data generation\n\nn=50 # nb samples\n\nmu_s=np.array([0,0])\ncov_s=np.array([[1,0],[0,1]])\n\nmu_t=np.array([4,4])\ncov_t=np.array([[1,-.8],[-.8,1]])\n\nxs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s)\nxt=ot.datasets.get_2D_samples_gauss(n,mu_t,cov_t)\n\na,b = ot.unif(n),ot.unif(n) # uniform distribution on samples\n\n# loss matrix\nM=ot.dist(xs,xt)\nM/=M.max()\n\n#%% plot samples\n\npl.figure(1)\npl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\npl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\npl.legend(loc=0)\npl.title('Source and traget distributions')\n\npl.figure(2)\npl.imshow(M,interpolation='nearest')\npl.title('Cost matrix M')\n\n\n#%% EMD\n\nG0=ot.emd(a,b,M)\n\npl.figure(3)\npl.imshow(G0,interpolation='nearest')\npl.title('OT matrix G0')\n\npl.figure(4)\not.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])\npl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\npl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\npl.legend(loc=0)\npl.title('OT matrix with samples')\n\n\n#%% sinkhorn\n\n# reg term\nlambd=5e-4\n\nGs=ot.sinkhorn(a,b,M,lambd)\n\npl.figure(5)\npl.imshow(Gs,interpolation='nearest')\npl.title('OT matrix sinkhorn')\n\npl.figure(6)\not.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1])\npl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\npl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\npl.legend(loc=0)\npl.title('OT matrix Sinkhorn with samples')" ], "outputs": [], "metadata": { |