summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples/plot_OT_2D_samples.rst
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-08-30 17:01:01 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-08-30 17:01:01 +0200
commitdc8737a30cb6d9f1305173eb8d16fe6716fd1231 (patch)
tree1f03384de2af88ed07a1e850e0871db826ed53e7 /docs/source/auto_examples/plot_OT_2D_samples.rst
parentc2a7a1f3ab4ba5c4f5adeca0fa22d8d6b4fc079d (diff)
wroking make!
Diffstat (limited to 'docs/source/auto_examples/plot_OT_2D_samples.rst')
-rw-r--r--docs/source/auto_examples/plot_OT_2D_samples.rst68
1 files changed, 32 insertions, 36 deletions
diff --git a/docs/source/auto_examples/plot_OT_2D_samples.rst b/docs/source/auto_examples/plot_OT_2D_samples.rst
index e05e591..c472c6a 100644
--- a/docs/source/auto_examples/plot_OT_2D_samples.rst
+++ b/docs/source/auto_examples/plot_OT_2D_samples.rst
@@ -7,7 +7,6 @@
2D Optimal transport between empirical distributions
====================================================
-@author: rflamary
@@ -46,69 +45,64 @@
:scale: 47
-.. rst-class:: sphx-glr-script-out
- Out::
-
- ('Warning: numerical errors at iteration', 0)
-
-
-
-
-|
.. code-block:: python
+ # Author: Remi Flamary <remi.flamary@unice.fr>
+ #
+ # License: MIT License
+
import numpy as np
import matplotlib.pylab as pl
import ot
#%% parameters and data generation
- n=50 # nb samples
+ n = 50 # nb samples
- mu_s=np.array([0,0])
- cov_s=np.array([[1,0],[0,1]])
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
- mu_t=np.array([4,4])
- cov_t=np.array([[1,-.8],[-.8,1]])
+ mu_t = np.array([4, 4])
+ cov_t = np.array([[1, -.8], [-.8, 1]])
- xs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s)
- xt=ot.datasets.get_2D_samples_gauss(n,mu_t,cov_t)
+ xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s)
+ xt = ot.datasets.get_2D_samples_gauss(n, mu_t, cov_t)
- a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples
+ a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
# loss matrix
- M=ot.dist(xs,xt)
- M/=M.max()
+ M = ot.dist(xs, xt)
+ M /= M.max()
#%% plot samples
pl.figure(1)
- pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
- pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.legend(loc=0)
- pl.title('Source and traget distributions')
+ pl.title('Source and target distributions')
pl.figure(2)
- pl.imshow(M,interpolation='nearest')
+ pl.imshow(M, interpolation='nearest')
pl.title('Cost matrix M')
#%% EMD
- G0=ot.emd(a,b,M)
+ G0 = ot.emd(a, b, M)
pl.figure(3)
- pl.imshow(G0,interpolation='nearest')
+ pl.imshow(G0, interpolation='nearest')
pl.title('OT matrix G0')
pl.figure(4)
- ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])
- pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
- pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
+ ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1])
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.legend(loc=0)
pl.title('OT matrix with samples')
@@ -116,22 +110,24 @@
#%% sinkhorn
# reg term
- lambd=5e-4
+ lambd = 1e-3
- Gs=ot.sinkhorn(a,b,M,lambd)
+ Gs = ot.sinkhorn(a, b, M, lambd)
pl.figure(5)
- pl.imshow(Gs,interpolation='nearest')
+ pl.imshow(Gs, interpolation='nearest')
pl.title('OT matrix sinkhorn')
pl.figure(6)
- ot.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1])
- pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
- pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
+ ot.plot.plot2D_samples_mat(xs, xt, Gs, color=[.5, .5, 1])
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.legend(loc=0)
pl.title('OT matrix Sinkhorn with samples')
-**Total running time of the script:** ( 0 minutes 0.623 seconds)
+ pl.show()
+
+**Total running time of the script:** ( 0 minutes 2.908 seconds)