From 38e96f88eb520b9fa8333686565b082d2921e131 Mon Sep 17 00:00:00 2001
From: Rémi Flamary
Date: Tue, 13 Jun 2017 15:50:11 +0200
Subject: implement parallel sinkhorn stabilized

---
 examples/plot_OT_2D_samples.py |  4 +--
 examples/plot_compute_emd.py   |  9 +++---
 ot/bregman.py                  | 68 +++++++++++++++++++++++++++++++-----------
 3 files changed, 58 insertions(+), 23 deletions(-)

diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py
index 3b95083..edfb781 100644
--- a/examples/plot_OT_2D_samples.py
+++ b/examples/plot_OT_2D_samples.py
@@ -13,7 +13,7 @@ import ot

 #%% parameters and data generation

-n=2 # nb samples
+n=50 # nb samples

 mu_s=np.array([0,0])
 cov_s=np.array([[1,0],[0,1]])
@@ -62,7 +62,7 @@ pl.title('OT matrix with samples')
 #%% sinkhorn

 # reg term
-lambd=5e-3
+lambd=5e-4

 Gs=ot.sinkhorn(a,b,M,lambd)

diff --git a/examples/plot_compute_emd.py b/examples/plot_compute_emd.py
index 226bc97..08de6ee 100644
--- a/examples/plot_compute_emd.py
+++ b/examples/plot_compute_emd.py
@@ -16,7 +16,7 @@ from ot.datasets import get_1D_gauss as gauss
 #%% parameters

 n=100 # nb bins
-n_target=10 # nb target distributions
+n_target=50 # nb target distributions


 # bin positions
@@ -61,14 +61,15 @@ pl.legend()

 #%%
 reg=1e-2
-d_sinkhorn=ot.sinkhorn(a,B,M,reg)
+d_sinkhorn=ot.sinkhorn(a,B,M,reg,method='sinkhorn_stabilized')
+d_sinkhorn0=ot.sinkhorn(a,B,M,reg)
 d_sinkhorn2=ot.sinkhorn(a,B,M2,reg)

 pl.figure(2)
 pl.clf()
 pl.plot(d_emd,label='Euclidean EMD')
 pl.plot(d_emd2,label='Squared Euclidean EMD')
-pl.plot(d_sinkhorn,label='Euclidean Sinkhorn')
-pl.plot(d_emd2,label='Squared Euclidean Sinkhorn')
+pl.plot(d_sinkhorn,'+',label='Euclidean Sinkhorn')
+pl.plot(d_sinkhorn2,'+',label='Squared Euclidean Sinkhorn')
 pl.title('EMD distances')
 pl.legend()
\ No newline at end of file
diff --git a/ot/bregman.py b/ot/bregman.py
index 3a9b15f..e847f24 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -32,8 +32,9 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver
     ----------
     a : np.ndarray (ns,)
         samples weights in the source domain
-    b : np.ndarray (nt,)
-        samples in the target domain
+    b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+        samples in the target domain, compute sinkhorn with multiple targets
+        and fixed M if b is a matrix (return OT loss + dual variables in log)
     M : np.ndarray (ns,nt)
         loss matrix
     reg : float
@@ -134,8 +135,9 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
     ----------
     a : np.ndarray (ns,)
         samples weights in the source domain
-    b : np.ndarray (nt,)
-        samples in the target domain
+    b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+        samples in the target domain, compute sinkhorn with multiple targets
+        and fixed M if b is a matrix (return OT loss + dual variables in log)
     M : np.ndarray (ns,nt)
         loss matrix
     reg : float
@@ -369,11 +371,17 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
     if len(b)==0:
         b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1]

+    # test if multiple target
+    if len(b.shape)>1:
+        nbb=b.shape[1]
+        a=a[:,np.newaxis]
+    else:
+        nbb=0
+
     # init data
     na = len(a)
     nb = len(b)

-
     cpt = 0
     if log:
         log={'err':[]}
@@ -383,10 +391,11 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
         alpha,beta=np.zeros(na),np.zeros(nb)
     else:
         alpha,beta=warmstart
-    u,v = np.ones(na)/na,np.ones(nb)/nb
-
-    #print(reg)
-
+
+    if nbb:
+        u,v = np.ones((na,nbb))/na,np.ones((nb,nbb))/nb
+    else:
+        u,v = np.ones(na)/na,np.ones(nb)/nb

     def get_K(alpha,beta):
         """log space computation"""
@@ -405,22 +414,32 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
     err=1
     while loop:

+        # remove numerical problems and store them in K
         if np.abs(u).max()>tau or np.abs(v).max()>tau:
-            alpha,beta=alpha+reg*np.log(u),beta+reg*np.log(v)
-            u,v = np.ones(na)/na,np.ones(nb)/nb
+            if nbb:
+                alpha,beta=alpha+reg*np.max(np.log(u),1),beta+reg*np.max(np.log(v))
+            else:
+                alpha,beta=alpha+reg*np.log(u),beta+reg*np.log(v)
+            if nbb:
+                u,v = np.ones((na,nbb))/na,np.ones((nb,nbb))/nb
+            else:
+                u,v = np.ones(na)/na,np.ones(nb)/nb
             K=get_K(alpha,beta)

         uprev = u
         vprev = v
+
+        # sinkhorn update
         v = b/np.dot(K.T,u)
         u = a/np.dot(K,v)
-
-


         if cpt%print_period==0: # we can speed up the process by checking for the error only all the 10th iterations
-            transp = get_Gamma(alpha,beta,u,v)
-            err = np.linalg.norm((np.sum(transp,axis=0)-b))**2
+            if nbb:
+                err = np.sum((u-uprev)**2)/np.sum((u)**2)+np.sum((v-vprev)**2)/np.sum((v)**2)
+            else:
+                transp = get_Gamma(alpha,beta,u,v)
+                err = np.linalg.norm((np.sum(transp,axis=0)-b))**2
             if log:
                 log['err'].append(err)

@@ -448,6 +467,8 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
             break

         cpt = cpt +1
+
+    #print('err=',err,' cpt=',cpt)

     if log:
         log['logu']=alpha/reg+np.log(u)
@@ -455,9 +476,22 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
         log['alpha']=alpha+reg*np.log(u)
         log['beta']=beta+reg*np.log(v)
         log['warmstart']=(log['alpha'],log['beta'])
-        return get_Gamma(alpha,beta,u,v),log
+        if nbb:
+            res=np.zeros((nbb))
+            for i in range(nbb):
+                res[i]=np.sum(get_Gamma(alpha,beta,u[:,i],v[:,i])*M)
+            return res,log
+
+        else:
+            return get_Gamma(alpha,beta,u,v),log
     else:
-        return get_Gamma(alpha,beta,u,v)
+        if nbb:
+            res=np.zeros((nbb))
+            for i in range(nbb):
+                res[i]=np.sum(get_Gamma(alpha,beta,u[:,i],v[:,i])*M)
+            return res
+        else:
+            return get_Gamma(alpha,beta,u,v)

 def sinkhorn_epsilon_scaling(a,b, M, reg, numItermax = 100, epsilon0=1e4, numInnerItermax = 100,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=10, log=False,**kwargs):
     """
-- 
cgit v1.2.3
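For readers who want to try the multi-target mode added above: the sketch below is not part of the patch. It assumes the POT API at this revision (ot.sinkhorn, ot.dist, ot.datasets.get_1D_gauss, as used in examples/plot_compute_emd.py) and builds small toy histograms purely for illustration. With a 2-D b, method='sinkhorn_stabilized' returns one regularized OT loss per target column rather than a transport matrix.

# Illustrative sketch (not part of the patch): one source histogram solved
# against several target histograms in a single stabilized Sinkhorn call.
import numpy as np
import ot
from ot.datasets import get_1D_gauss as gauss

n = 100                                                    # nb bins
x = np.arange(n, dtype=np.float64)
a = gauss(n, 20, 5)                                        # source histogram
B = np.vstack([gauss(n, m, 5) for m in [30, 40, 60]]).T    # (n, nbb) target histograms
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))          # loss matrix
M /= M.max()

reg = 1e-2
# with a matrix b, the stabilized solver returns one OT loss per column of b
d = ot.sinkhorn(a, B, M, reg, method='sinkhorn_stabilized')
print(d.shape)   # (3,)

In the hunks above, each column i of b keeps its own scaling vectors u[:,i] and v[:,i], while all columns share the kernel K built from M and the common offsets alpha, beta; that shared kernel is what lets the stabilized solver process all targets in parallel.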