summaryrefslogtreecommitdiff
path: root/examples/demo_OT_2D_samples.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-27 17:29:36 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-27 17:29:36 +0200
commitd79fee64d4d8d318375f0ccba25bf71ae37ed1ea (patch)
tree8e2877034839cdd4d22109a1d4d3716d7e1dfe84 /examples/demo_OT_2D_samples.py
parent1c554920fbcbf15338a1cc3e50c7f27e0cb89597 (diff)
new notebook
Diffstat (limited to 'examples/demo_OT_2D_samples.py')
-rw-r--r--examples/demo_OT_2D_samples.py17
1 files changed, 9 insertions, 8 deletions
diff --git a/examples/demo_OT_2D_samples.py b/examples/demo_OT_2D_samples.py
index d7de2c7..5fc214f 100644
--- a/examples/demo_OT_2D_samples.py
+++ b/examples/demo_OT_2D_samples.py
@@ -13,7 +13,7 @@ import ot
#%% parameters
-n=20 # nb bins
+n=20 # nb samples
mu_s=np.array([0,0])
cov_s=np.array([[1,0],[0,1]])
@@ -28,7 +28,7 @@ a,b = ot.unif(n),ot.unif(n)
# loss matrix
M=ot.dist(xs,xt)
-#M/=M.max()
+M/=M.max()
#%% plot samples
@@ -50,32 +50,33 @@ G0=ot.emd(a,b,M)
pl.figure(3)
pl.imshow(G0,interpolation='nearest')
-pl.title('Cost matrix M')
+pl.title('OT matrix G0')
pl.figure(4)
ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])
pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
pl.legend(loc=0)
-pl.title('OT matrix')
+pl.title('OT matrix with samples')
#%% sinkhorn
-lambd=1e-1
-Gs=ot.sinkhorn(a,b,M,lambd)
+# reg term
+lambd=5e-3
+Gs=ot.sinkhorn(a,b,M,lambd)
pl.figure(5)
pl.imshow(Gs,interpolation='nearest')
-pl.title('Cost matrix M')
+pl.title('OT matrix sinkhorn')
pl.figure(6)
ot.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1])
pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
pl.legend(loc=0)
-pl.title('OT matrix Sinkhorn')
+pl.title('OT matrix Sinkhorn with samples')
#
#pl.figure(3)