diff options
Diffstat (limited to 'ot/bregman.py')
-rw-r--r-- | ot/bregman.py | 54 |
1 files changed, 39 insertions, 15 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index 6b3c68b..3a9b15f 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -7,7 +7,7 @@ import numpy as np def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs): u""" - Solve the entropic regularization optimal transport problem + Solve the entropic regularization optimal transport problem and return the OT matrix The function solves the following optimization problem: @@ -107,12 +107,9 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver return sink() - - - def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs): """ - Solve the entropic regularization optimal transport problem + Solve the entropic regularization optimal transport problem and return the OT matrix The function solves the following optimization problem: @@ -188,22 +185,35 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, a=np.asarray(a,dtype=np.float64) b=np.asarray(b,dtype=np.float64) M=np.asarray(M,dtype=np.float64) + if len(a)==0: a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0] if len(b)==0: b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1] + # init data Nini = len(a) Nfin = len(b) + + if len(b.shape)>1: + nbb=b.shape[1] + else: + nbb=0 + if log: log={'err':[]} # we assume that no distances are null except those of the diagonal of distances - u = np.ones(Nini)/Nini - v = np.ones(Nfin)/Nfin + if nbb: + u = np.ones((Nini,nbb))/Nini + v = np.ones((Nfin,nbb))/Nfin + else: + u = np.ones(Nini)/Nini + v = np.ones(Nfin)/Nfin + #print(reg) @@ -231,8 +241,11 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, break if cpt%10==0: # we can speed up the process by checking for the error only all the 10th iterations - transp = u.reshape(-1, 1) * (K * v) - err = np.linalg.norm((np.sum(transp,axis=0)-b))**2 + if nbb: + err = np.sum((u-uprev)**2)/np.sum((u)**2)+np.sum((v-vprev)**2)/np.sum((v)**2) + else: + transp = u.reshape(-1, 1) * (K * v) + err = np.linalg.norm((np.sum(transp,axis=0)-b))**2 if log: log['err'].append(err) @@ -244,12 +257,23 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, if log: log['u']=u log['v']=v - - #print('err=',err,' cpt=',cpt) - if log: - return u.reshape((-1,1))*K*v.reshape((1,-1)),log - else: - return u.reshape((-1,1))*K*v.reshape((1,-1)) + + if nbb: #return only loss + res=np.zeros((nbb)) + for i in range(nbb): + res[i]=np.sum(u[:,i].reshape((-1,1))*K*v[:,i].reshape((1,-1))*M) + if log: + return res,log + else: + return res + + else: # return OT matrix + + if log: + return u.reshape((-1,1))*K*v.reshape((1,-1)),log + else: + return u.reshape((-1,1))*K*v.reshape((1,-1)) + def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=20, log=False,**kwargs): """ |