From a7bed093f91922e18fa5902c4d1d63b9712d5794 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Tue, 13 Jun 2017 15:14:45 +0200 Subject: implement paralell sinkhorn --- examples/plot_OT_1D.py | 2 +- examples/plot_compute_emd.py | 31 ++++++++++++++++++++++++------- 2 files changed, 25 insertions(+), 8 deletions(-) (limited to 'examples') diff --git a/examples/plot_OT_1D.py b/examples/plot_OT_1D.py index e5719eb..6661aa3 100644 --- a/examples/plot_OT_1D.py +++ b/examples/plot_OT_1D.py @@ -50,7 +50,7 @@ ot.plot.plot1D_mat(a,b,G0,'OT matrix G0') #%% Sinkhorn lambd=1e-3 -Gs=ot.sinkhorn(a,b,M,lambd) +Gs=ot.sinkhorn(a,b,M,lambd,verbose=True) pl.figure(4) ot.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn') diff --git a/examples/plot_compute_emd.py b/examples/plot_compute_emd.py index 87b39a6..226bc97 100644 --- a/examples/plot_compute_emd.py +++ b/examples/plot_compute_emd.py @@ -32,10 +32,11 @@ B=np.zeros((n,n_target)) for i,m in enumerate(lst_m): B[:,i]=gauss(n,m=m,s=5) -# loss matrix +# loss matrix and normalization M=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'euclidean') +M/=M.max() M2=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'sqeuclidean') - +M2/=M2.max() #%% plot the distributions pl.figure(1) @@ -46,12 +47,28 @@ pl.subplot(2,1,2) pl.plot(x,B,label='Target distributions') pl.title('Target distributions') -#%% plot distributions and loss matrix +#%% Compute and plot distributions and loss matrix + +d_emd=ot.emd2(a,B,M) # direct computation of EMD +d_emd2=ot.emd2(a,B,M2) # direct computation of EMD with loss M3 + -emd=ot.emd2(a,B,M) -emd2=ot.emd2(a,B,M2) pl.figure(2) -pl.plot(emd,label='Euclidean loss') -pl.plot(emd,label='Squared Euclidean loss') +pl.plot(d_emd,label='Euclidean EMD') +pl.plot(d_emd2,label='Squared Euclidean EMD') +pl.title('EMD distances') pl.legend() +#%% +reg=1e-2 +d_sinkhorn=ot.sinkhorn(a,B,M,reg) +d_sinkhorn2=ot.sinkhorn(a,B,M2,reg) + +pl.figure(2) +pl.clf() +pl.plot(d_emd,label='Euclidean EMD') +pl.plot(d_emd2,label='Squared Euclidean EMD') +pl.plot(d_sinkhorn,label='Euclidean Sinkhorn') +pl.plot(d_emd2,label='Squared Euclidean Sinkhorn') +pl.title('EMD distances') +pl.legend() \ No newline at end of file -- cgit v1.2.3