diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-27 17:29:36 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-27 17:29:36 +0200 |
commit | d79fee64d4d8d318375f0ccba25bf71ae37ed1ea (patch) | |
tree | 8e2877034839cdd4d22109a1d4d3716d7e1dfe84 /examples/demo_OT_2D_samples.py | |
parent | 1c554920fbcbf15338a1cc3e50c7f27e0cb89597 (diff) |
new notebook
Diffstat (limited to 'examples/demo_OT_2D_samples.py')
-rw-r--r-- | examples/demo_OT_2D_samples.py | 17 |
1 files changed, 9 insertions, 8 deletions
diff --git a/examples/demo_OT_2D_samples.py b/examples/demo_OT_2D_samples.py index d7de2c7..5fc214f 100644 --- a/examples/demo_OT_2D_samples.py +++ b/examples/demo_OT_2D_samples.py @@ -13,7 +13,7 @@ import ot #%% parameters -n=20 # nb bins +n=20 # nb samples mu_s=np.array([0,0]) cov_s=np.array([[1,0],[0,1]]) @@ -28,7 +28,7 @@ a,b = ot.unif(n),ot.unif(n) # loss matrix M=ot.dist(xs,xt) -#M/=M.max() +M/=M.max() #%% plot samples @@ -50,32 +50,33 @@ G0=ot.emd(a,b,M) pl.figure(3) pl.imshow(G0,interpolation='nearest') -pl.title('Cost matrix M') +pl.title('OT matrix G0') pl.figure(4) ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1]) pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') pl.legend(loc=0) -pl.title('OT matrix') +pl.title('OT matrix with samples') #%% sinkhorn -lambd=1e-1 -Gs=ot.sinkhorn(a,b,M,lambd) +# reg term +lambd=5e-3 +Gs=ot.sinkhorn(a,b,M,lambd) pl.figure(5) pl.imshow(Gs,interpolation='nearest') -pl.title('Cost matrix M') +pl.title('OT matrix sinkhorn') pl.figure(6) ot.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1]) pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') pl.legend(loc=0) -pl.title('OT matrix Sinkhorn') +pl.title('OT matrix Sinkhorn with samples') # #pl.figure(3) |