diff options
-rw-r--r-- | examples/plot_OT_1D.py | 2 | ||||
-rw-r--r-- | examples/plot_compute_emd.py | 31 | ||||
-rw-r--r-- | ot/bregman.py | 54 |
3 files changed, 64 insertions, 23 deletions
diff --git a/examples/plot_OT_1D.py b/examples/plot_OT_1D.py index e5719eb..6661aa3 100644 --- a/examples/plot_OT_1D.py +++ b/examples/plot_OT_1D.py @@ -50,7 +50,7 @@ ot.plot.plot1D_mat(a,b,G0,'OT matrix G0') #%% Sinkhorn lambd=1e-3 -Gs=ot.sinkhorn(a,b,M,lambd) +Gs=ot.sinkhorn(a,b,M,lambd,verbose=True) pl.figure(4) ot.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn') diff --git a/examples/plot_compute_emd.py b/examples/plot_compute_emd.py index 87b39a6..226bc97 100644 --- a/examples/plot_compute_emd.py +++ b/examples/plot_compute_emd.py @@ -32,10 +32,11 @@ B=np.zeros((n,n_target)) for i,m in enumerate(lst_m): B[:,i]=gauss(n,m=m,s=5) -# loss matrix +# loss matrix and normalization M=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'euclidean') +M/=M.max() M2=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'sqeuclidean') - +M2/=M2.max() #%% plot the distributions pl.figure(1) @@ -46,12 +47,28 @@ pl.subplot(2,1,2) pl.plot(x,B,label='Target distributions') pl.title('Target distributions') -#%% plot distributions and loss matrix +#%% Compute and plot distributions and loss matrix + +d_emd=ot.emd2(a,B,M) # direct computation of EMD +d_emd2=ot.emd2(a,B,M2) # direct computation of EMD with loss M3 + -emd=ot.emd2(a,B,M) -emd2=ot.emd2(a,B,M2) pl.figure(2) -pl.plot(emd,label='Euclidean loss') -pl.plot(emd,label='Squared Euclidean loss') +pl.plot(d_emd,label='Euclidean EMD') +pl.plot(d_emd2,label='Squared Euclidean EMD') +pl.title('EMD distances') pl.legend() +#%% +reg=1e-2 +d_sinkhorn=ot.sinkhorn(a,B,M,reg) +d_sinkhorn2=ot.sinkhorn(a,B,M2,reg) + +pl.figure(2) +pl.clf() +pl.plot(d_emd,label='Euclidean EMD') +pl.plot(d_emd2,label='Squared Euclidean EMD') +pl.plot(d_sinkhorn,label='Euclidean Sinkhorn') +pl.plot(d_emd2,label='Squared Euclidean Sinkhorn') +pl.title('EMD distances') +pl.legend()
\ No newline at end of file diff --git a/ot/bregman.py b/ot/bregman.py index 6b3c68b..3a9b15f 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -7,7 +7,7 @@ import numpy as np def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs): u""" - Solve the entropic regularization optimal transport problem + Solve the entropic regularization optimal transport problem and return the OT matrix The function solves the following optimization problem: @@ -107,12 +107,9 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver return sink() - - - def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs): """ - Solve the entropic regularization optimal transport problem + Solve the entropic regularization optimal transport problem and return the OT matrix The function solves the following optimization problem: @@ -188,22 +185,35 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, a=np.asarray(a,dtype=np.float64) b=np.asarray(b,dtype=np.float64) M=np.asarray(M,dtype=np.float64) + if len(a)==0: a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0] if len(b)==0: b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1] + # init data Nini = len(a) Nfin = len(b) + + if len(b.shape)>1: + nbb=b.shape[1] + else: + nbb=0 + if log: log={'err':[]} # we assume that no distances are null except those of the diagonal of distances - u = np.ones(Nini)/Nini - v = np.ones(Nfin)/Nfin + if nbb: + u = np.ones((Nini,nbb))/Nini + v = np.ones((Nfin,nbb))/Nfin + else: + u = np.ones(Nini)/Nini + v = np.ones(Nfin)/Nfin + #print(reg) @@ -231,8 +241,11 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, break if cpt%10==0: # we can speed up the process by checking for the error only all the 10th iterations - transp = u.reshape(-1, 1) * (K * v) - err = np.linalg.norm((np.sum(transp,axis=0)-b))**2 + if nbb: + err = np.sum((u-uprev)**2)/np.sum((u)**2)+np.sum((v-vprev)**2)/np.sum((v)**2) + else: + transp = u.reshape(-1, 1) * (K * v) + err = np.linalg.norm((np.sum(transp,axis=0)-b))**2 if log: log['err'].append(err) @@ -244,12 +257,23 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, if log: log['u']=u log['v']=v - - #print('err=',err,' cpt=',cpt) - if log: - return u.reshape((-1,1))*K*v.reshape((1,-1)),log - else: - return u.reshape((-1,1))*K*v.reshape((1,-1)) + + if nbb: #return only loss + res=np.zeros((nbb)) + for i in range(nbb): + res[i]=np.sum(u[:,i].reshape((-1,1))*K*v[:,i].reshape((1,-1))*M) + if log: + return res,log + else: + return res + + else: # return OT matrix + + if log: + return u.reshape((-1,1))*K*v.reshape((1,-1)),log + else: + return u.reshape((-1,1))*K*v.reshape((1,-1)) + def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=20, log=False,**kwargs): """ |