summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples/plot_OT_2D_samples.ipynb
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-07-03 16:33:00 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-07-03 16:33:00 +0200
commitadbf95ee9a720fa38b5b91d0a9d5c3b22ba0b226 (patch)
tree834a6934ab09c4daec751c220ba847e8be61bb4b /docs/source/auto_examples/plot_OT_2D_samples.ipynb
parentf639518e9b96c5904122e62e024ed4ae369ceb33 (diff)
update doc
Diffstat (limited to 'docs/source/auto_examples/plot_OT_2D_samples.ipynb')
-rw-r--r--docs/source/auto_examples/plot_OT_2D_samples.ipynb2
1 files changed, 1 insertions, 1 deletions
diff --git a/docs/source/auto_examples/plot_OT_2D_samples.ipynb b/docs/source/auto_examples/plot_OT_2D_samples.ipynb
index 7d42ba7..fad0467 100644
--- a/docs/source/auto_examples/plot_OT_2D_samples.ipynb
+++ b/docs/source/auto_examples/plot_OT_2D_samples.ipynb
@@ -24,7 +24,7 @@
"execution_count": null,
"cell_type": "code",
"source": [
- "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\n\n#%% parameters and data generation\n\nn=2 # nb samples\n\nmu_s=np.array([0,0])\ncov_s=np.array([[1,0],[0,1]])\n\nmu_t=np.array([4,4])\ncov_t=np.array([[1,-.8],[-.8,1]])\n\nxs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s)\nxt=ot.datasets.get_2D_samples_gauss(n,mu_t,cov_t)\n\na,b = ot.unif(n),ot.unif(n) # uniform distribution on samples\n\n# loss matrix\nM=ot.dist(xs,xt)\nM/=M.max()\n\n#%% plot samples\n\npl.figure(1)\npl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\npl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\npl.legend(loc=0)\npl.title('Source and traget distributions')\n\npl.figure(2)\npl.imshow(M,interpolation='nearest')\npl.title('Cost matrix M')\n\n\n#%% EMD\n\nG0=ot.emd(a,b,M)\n\npl.figure(3)\npl.imshow(G0,interpolation='nearest')\npl.title('OT matrix G0')\n\npl.figure(4)\not.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])\npl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\npl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\npl.legend(loc=0)\npl.title('OT matrix with samples')\n\n\n#%% sinkhorn\n\n# reg term\nlambd=5e-3\n\nGs=ot.sinkhorn(a,b,M,lambd)\n\npl.figure(5)\npl.imshow(Gs,interpolation='nearest')\npl.title('OT matrix sinkhorn')\n\npl.figure(6)\not.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1])\npl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\npl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\npl.legend(loc=0)\npl.title('OT matrix Sinkhorn with samples')"
+ "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\n\n#%% parameters and data generation\n\nn=50 # nb samples\n\nmu_s=np.array([0,0])\ncov_s=np.array([[1,0],[0,1]])\n\nmu_t=np.array([4,4])\ncov_t=np.array([[1,-.8],[-.8,1]])\n\nxs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s)\nxt=ot.datasets.get_2D_samples_gauss(n,mu_t,cov_t)\n\na,b = ot.unif(n),ot.unif(n) # uniform distribution on samples\n\n# loss matrix\nM=ot.dist(xs,xt)\nM/=M.max()\n\n#%% plot samples\n\npl.figure(1)\npl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\npl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\npl.legend(loc=0)\npl.title('Source and traget distributions')\n\npl.figure(2)\npl.imshow(M,interpolation='nearest')\npl.title('Cost matrix M')\n\n\n#%% EMD\n\nG0=ot.emd(a,b,M)\n\npl.figure(3)\npl.imshow(G0,interpolation='nearest')\npl.title('OT matrix G0')\n\npl.figure(4)\not.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])\npl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\npl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\npl.legend(loc=0)\npl.title('OT matrix with samples')\n\n\n#%% sinkhorn\n\n# reg term\nlambd=5e-4\n\nGs=ot.sinkhorn(a,b,M,lambd)\n\npl.figure(5)\npl.imshow(Gs,interpolation='nearest')\npl.title('OT matrix sinkhorn')\n\npl.figure(6)\not.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1])\npl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\npl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\npl.legend(loc=0)\npl.title('OT matrix Sinkhorn with samples')"
],
"outputs": [],
"metadata": {