diff options
Diffstat (limited to 'ot/bregman.py')
-rw-r--r-- | ot/bregman.py | 68 |
1 files changed, 51 insertions, 17 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index 3a9b15f..e847f24 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -32,8 +32,9 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver ---------- a : np.ndarray (ns,) samples weights in the source domain - b : np.ndarray (nt,) - samples in the target domain + b : np.ndarray (nt,) or np.ndarray (nt,nbb) + samples in the target domain, compute sinkhorn with multiple targets + and fixed M if b is a matrix (return OT loss + dual variables in log) M : np.ndarray (ns,nt) loss matrix reg : float @@ -134,8 +135,9 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, ---------- a : np.ndarray (ns,) samples weights in the source domain - b : np.ndarray (nt,) - samples in the target domain + b : np.ndarray (nt,) or np.ndarray (nt,nbb) + samples in the target domain, compute sinkhorn with multiple targets + and fixed M if b is a matrix (return OT loss + dual variables in log) M : np.ndarray (ns,nt) loss matrix reg : float @@ -369,11 +371,17 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war if len(b)==0: b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1] + # test if multiple target + if len(b.shape)>1: + nbb=b.shape[1] + a=a[:,np.newaxis] + else: + nbb=0 + # init data na = len(a) nb = len(b) - cpt = 0 if log: log={'err':[]} @@ -383,10 +391,11 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war alpha,beta=np.zeros(na),np.zeros(nb) else: alpha,beta=warmstart - u,v = np.ones(na)/na,np.ones(nb)/nb - - #print(reg) - + + if nbb: + u,v = np.ones((na,nbb))/na,np.ones((nb,nbb))/nb + else: + u,v = np.ones(na)/na,np.ones(nb)/nb def get_K(alpha,beta): """log space computation""" @@ -405,22 +414,32 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war err=1 while loop: + # remove numerical problems and store them in K if np.abs(u).max()>tau or np.abs(v).max()>tau: - alpha,beta=alpha+reg*np.log(u),beta+reg*np.log(v) - u,v = np.ones(na)/na,np.ones(nb)/nb + if nbb: + alpha,beta=alpha+reg*np.max(np.log(u),1),beta+reg*np.max(np.log(v)) + else: + alpha,beta=alpha+reg*np.log(u),beta+reg*np.log(v) + if nbb: + u,v = np.ones((na,nbb))/na,np.ones((nb,nbb))/nb + else: + u,v = np.ones(na)/na,np.ones(nb)/nb K=get_K(alpha,beta) uprev = u vprev = v + + # sinkhorn update v = b/np.dot(K.T,u) u = a/np.dot(K,v) - - if cpt%print_period==0: # we can speed up the process by checking for the error only all the 10th iterations - transp = get_Gamma(alpha,beta,u,v) - err = np.linalg.norm((np.sum(transp,axis=0)-b))**2 + if nbb: + err = np.sum((u-uprev)**2)/np.sum((u)**2)+np.sum((v-vprev)**2)/np.sum((v)**2) + else: + transp = get_Gamma(alpha,beta,u,v) + err = np.linalg.norm((np.sum(transp,axis=0)-b))**2 if log: log['err'].append(err) @@ -448,6 +467,8 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war break cpt = cpt +1 + + #print('err=',err,' cpt=',cpt) if log: log['logu']=alpha/reg+np.log(u) @@ -455,9 +476,22 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war log['alpha']=alpha+reg*np.log(u) log['beta']=beta+reg*np.log(v) log['warmstart']=(log['alpha'],log['beta']) - return get_Gamma(alpha,beta,u,v),log + if nbb: + res=np.zeros((nbb)) + for i in range(nbb): + res[i]=np.sum(get_Gamma(alpha,beta,u[:,i],v[:,i])*M) + return res,log + + else: + return get_Gamma(alpha,beta,u,v),log else: - return get_Gamma(alpha,beta,u,v) + if nbb: + res=np.zeros((nbb)) + for i in range(nbb): + res[i]=np.sum(get_Gamma(alpha,beta,u[:,i],v[:,i])*M) + return res + else: + return get_Gamma(alpha,beta,u,v) def sinkhorn_epsilon_scaling(a,b, M, reg, numItermax = 100, epsilon0=1e4, numInnerItermax = 100,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=10, log=False,**kwargs): """ |