From d79fee64d4d8d318375f0ccba25bf71ae37ed1ea Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Thu, 27 Oct 2016 17:29:36 +0200 Subject: new notebook --- examples/demo_OT_2D_samples.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) (limited to 'examples/demo_OT_2D_samples.py') diff --git a/examples/demo_OT_2D_samples.py b/examples/demo_OT_2D_samples.py index d7de2c7..5fc214f 100644 --- a/examples/demo_OT_2D_samples.py +++ b/examples/demo_OT_2D_samples.py @@ -13,7 +13,7 @@ import ot #%% parameters -n=20 # nb bins +n=20 # nb samples mu_s=np.array([0,0]) cov_s=np.array([[1,0],[0,1]]) @@ -28,7 +28,7 @@ a,b = ot.unif(n),ot.unif(n) # loss matrix M=ot.dist(xs,xt) -#M/=M.max() +M/=M.max() #%% plot samples @@ -50,32 +50,33 @@ G0=ot.emd(a,b,M) pl.figure(3) pl.imshow(G0,interpolation='nearest') -pl.title('Cost matrix M') +pl.title('OT matrix G0') pl.figure(4) ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1]) pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') pl.legend(loc=0) -pl.title('OT matrix') +pl.title('OT matrix with samples') #%% sinkhorn -lambd=1e-1 -Gs=ot.sinkhorn(a,b,M,lambd) +# reg term +lambd=5e-3 +Gs=ot.sinkhorn(a,b,M,lambd) pl.figure(5) pl.imshow(Gs,interpolation='nearest') -pl.title('Cost matrix M') +pl.title('OT matrix sinkhorn') pl.figure(6) ot.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1]) pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') pl.legend(loc=0) -pl.title('OT matrix Sinkhorn') +pl.title('OT matrix Sinkhorn with samples') # #pl.figure(3) -- cgit v1.2.3