summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py54
1 files changed, 39 insertions, 15 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 6b3c68b..3a9b15f 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -7,7 +7,7 @@ import numpy as np
def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):
u"""
- Solve the entropic regularization optimal transport problem
+ Solve the entropic regularization optimal transport problem and return the OT matrix
The function solves the following optimization problem:
@@ -107,12 +107,9 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver
return sink()
-
-
-
def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):
"""
- Solve the entropic regularization optimal transport problem
+ Solve the entropic regularization optimal transport problem and return the OT matrix
The function solves the following optimization problem:
@@ -188,22 +185,35 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
a=np.asarray(a,dtype=np.float64)
b=np.asarray(b,dtype=np.float64)
M=np.asarray(M,dtype=np.float64)
+
if len(a)==0:
a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0]
if len(b)==0:
b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1]
+
# init data
Nini = len(a)
Nfin = len(b)
+
+ if len(b.shape)>1:
+ nbb=b.shape[1]
+ else:
+ nbb=0
+
if log:
log={'err':[]}
# we assume that no distances are null except those of the diagonal of distances
- u = np.ones(Nini)/Nini
- v = np.ones(Nfin)/Nfin
+ if nbb:
+ u = np.ones((Nini,nbb))/Nini
+ v = np.ones((Nfin,nbb))/Nfin
+ else:
+ u = np.ones(Nini)/Nini
+ v = np.ones(Nfin)/Nfin
+
#print(reg)
@@ -231,8 +241,11 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
break
if cpt%10==0:
# we can speed up the process by checking for the error only all the 10th iterations
- transp = u.reshape(-1, 1) * (K * v)
- err = np.linalg.norm((np.sum(transp,axis=0)-b))**2
+ if nbb:
+ err = np.sum((u-uprev)**2)/np.sum((u)**2)+np.sum((v-vprev)**2)/np.sum((v)**2)
+ else:
+ transp = u.reshape(-1, 1) * (K * v)
+ err = np.linalg.norm((np.sum(transp,axis=0)-b))**2
if log:
log['err'].append(err)
@@ -244,12 +257,23 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
if log:
log['u']=u
log['v']=v
-
- #print('err=',err,' cpt=',cpt)
- if log:
- return u.reshape((-1,1))*K*v.reshape((1,-1)),log
- else:
- return u.reshape((-1,1))*K*v.reshape((1,-1))
+
+ if nbb: #return only loss
+ res=np.zeros((nbb))
+ for i in range(nbb):
+ res[i]=np.sum(u[:,i].reshape((-1,1))*K*v[:,i].reshape((1,-1))*M)
+ if log:
+ return res,log
+ else:
+ return res
+
+ else: # return OT matrix
+
+ if log:
+ return u.reshape((-1,1))*K*v.reshape((1,-1)),log
+ else:
+ return u.reshape((-1,1))*K*v.reshape((1,-1))
+
def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=20, log=False,**kwargs):
"""