summaryrefslogtreecommitdiff
path: root/examples/demo_OT_2D_samples.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/demo_OT_2D_samples.py')
-rw-r--r--examples/demo_OT_2D_samples.py17
1 files changed, 9 insertions, 8 deletions
diff --git a/examples/demo_OT_2D_samples.py b/examples/demo_OT_2D_samples.py
index d7de2c7..5fc214f 100644
--- a/examples/demo_OT_2D_samples.py
+++ b/examples/demo_OT_2D_samples.py
@@ -13,7 +13,7 @@ import ot
#%% parameters
-n=20 # nb bins
+n=20 # nb samples
mu_s=np.array([0,0])
cov_s=np.array([[1,0],[0,1]])
@@ -28,7 +28,7 @@ a,b = ot.unif(n),ot.unif(n)
# loss matrix
M=ot.dist(xs,xt)
-#M/=M.max()
+M/=M.max()
#%% plot samples
@@ -50,32 +50,33 @@ G0=ot.emd(a,b,M)
pl.figure(3)
pl.imshow(G0,interpolation='nearest')
-pl.title('Cost matrix M')
+pl.title('OT matrix G0')
pl.figure(4)
ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])
pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
pl.legend(loc=0)
-pl.title('OT matrix')
+pl.title('OT matrix with samples')
#%% sinkhorn
-lambd=1e-1
-Gs=ot.sinkhorn(a,b,M,lambd)
+# reg term
+lambd=5e-3
+Gs=ot.sinkhorn(a,b,M,lambd)
pl.figure(5)
pl.imshow(Gs,interpolation='nearest')
-pl.title('Cost matrix M')
+pl.title('OT matrix sinkhorn')
pl.figure(6)
ot.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1])
pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
pl.legend(loc=0)
-pl.title('OT matrix Sinkhorn')
+pl.title('OT matrix Sinkhorn with samples')
#
#pl.figure(3)