summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Rémi Flamary <remi.flamary@gmail.com> 2017-06-13 15:14:45 +0200
committer Rémi Flamary <remi.flamary@gmail.com> 2017-06-13 15:14:45 +0200
commit a7bed093f91922e18fa5902c4d1d63b9712d5794 (patch)
tree 30ac6d07116ff74d682b82c9f866def1a44a2aaa
parent 3af9b06b6b3c24eb02931cd5fbf798034dd6b8a1 (diff)
implement parallel sinkhorn
-rw-r--r-- examples/plot_OT_1D.py 2
-rw-r--r-- examples/plot_compute_emd.py 31
-rw-r--r-- ot/bregman.py 54
3 files changed, 64 insertions, 23 deletions
diff --git a/examples/plot_OT_1D.py b/examples/plot_OT_1D.py
index e5719eb..6661aa3 100644
--- a/examples/plot_OT_1D.py
+++ b/examples/plot_OT_1D.py
@@ -50,7 +50,7 @@ ot.plot.plot1D_mat(a,b,G0,'OT matrix G0')
#%% Sinkhorn
lambd=1e-3
-Gs=ot.sinkhorn(a,b,M,lambd)
+Gs=ot.sinkhorn(a,b,M,lambd,verbose=True)
pl.figure(4)
ot.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn')
diff --git a/examples/plot_compute_emd.py b/examples/plot_compute_emd.py
index 87b39a6..226bc97 100644
--- a/examples/plot_compute_emd.py
+++ b/examples/plot_compute_emd.py
@@ -32,10 +32,11 @@ B=np.zeros((n,n_target))
for i,m in enumerate(lst_m):
B[:,i]=gauss(n,m=m,s=5)
-# loss matrix
+# loss matrix and normalization
M=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'euclidean')
+M/=M.max()
M2=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'sqeuclidean')
-
+M2/=M2.max()
#%% plot the distributions
pl.figure(1)
@@ -46,12 +47,28 @@ pl.subplot(2,1,2)
pl.plot(x,B,label='Target distributions')
pl.title('Target distributions')
-#%% plot distributions and loss matrix
+#%% Compute and plot distributions and loss matrix
+
+d_emd=ot.emd2(a,B,M) # direct computation of EMD
+d_emd2=ot.emd2(a,B,M2) # direct computation of EMD with loss M2
+
-emd=ot.emd2(a,B,M)
-emd2=ot.emd2(a,B,M2)
pl.figure(2)
-pl.plot(emd,label='Euclidean loss')
-pl.plot(emd,label='Squared Euclidean loss')
+pl.plot(d_emd,label='Euclidean EMD')
+pl.plot(d_emd2,label='Squared Euclidean EMD')
+pl.title('EMD distances')
pl.legend()
+#%%
+reg=1e-2
+d_sinkhorn=ot.sinkhorn(a,B,M,reg)
+d_sinkhorn2=ot.sinkhorn(a,B,M2,reg)
+
+pl.figure(2)
+pl.clf()
+pl.plot(d_emd,label='Euclidean EMD')
+pl.plot(d_emd2,label='Squared Euclidean EMD')
+pl.plot(d_sinkhorn,label='Euclidean Sinkhorn')
+pl.plot(d_emd2,label='Squared Euclidean Sinkhorn')
+pl.title('EMD distances')
+pl.legend() \ No newline at end of file
diff --git a/ot/bregman.py b/ot/bregman.py
index 6b3c68b..3a9b15f 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -7,7 +7,7 @@ import numpy as np
def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):
u"""
- Solve the entropic regularization optimal transport problem
+ Solve the entropic regularization optimal transport problem and return the OT matrix
The function solves the following optimization problem:
@@ -107,12 +107,9 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver
return sink()
-
-
-
def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):
"""
- Solve the entropic regularization optimal transport problem
+ Solve the entropic regularization optimal transport problem and return the OT matrix
The function solves the following optimization problem:
@@ -188,22 +185,35 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
a=np.asarray(a,dtype=np.float64)
b=np.asarray(b,dtype=np.float64)
M=np.asarray(M,dtype=np.float64)
+
if len(a)==0:
a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0]
if len(b)==0:
b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1]
+
# init data
Nini = len(a)
Nfin = len(b)
+
+ if len(b.shape)>1:
+ nbb=b.shape[1]
+ else:
+ nbb=0
+
if log:
log={'err':[]}
# we assume that no distances are null except those of the diagonal of distances
- u = np.ones(Nini)/Nini
- v = np.ones(Nfin)/Nfin
+ if nbb:
+ u = np.ones((Nini,nbb))/Nini
+ v = np.ones((Nfin,nbb))/Nfin
+ else:
+ u = np.ones(Nini)/Nini
+ v = np.ones(Nfin)/Nfin
+
#print(reg)
@@ -231,8 +241,11 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
break
if cpt%10==0:
# we can speed up the process by checking for the error only all the 10th iterations
- transp = u.reshape(-1, 1) * (K * v)
- err = np.linalg.norm((np.sum(transp,axis=0)-b))**2
+ if nbb:
+ err = np.sum((u-uprev)**2)/np.sum((u)**2)+np.sum((v-vprev)**2)/np.sum((v)**2)
+ else:
+ transp = u.reshape(-1, 1) * (K * v)
+ err = np.linalg.norm((np.sum(transp,axis=0)-b))**2
if log:
log['err'].append(err)
@@ -244,12 +257,23 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
if log:
log['u']=u
log['v']=v
-
- #print('err=',err,' cpt=',cpt)
- if log:
- return u.reshape((-1,1))*K*v.reshape((1,-1)),log
- else:
- return u.reshape((-1,1))*K*v.reshape((1,-1))
+
+ if nbb: #return only loss
+ res=np.zeros((nbb))
+ for i in range(nbb):
+ res[i]=np.sum(u[:,i].reshape((-1,1))*K*v[:,i].reshape((1,-1))*M)
+ if log:
+ return res,log
+ else:
+ return res
+
+ else: # return OT matrix
+
+ if log:
+ return u.reshape((-1,1))*K*v.reshape((1,-1)),log
+ else:
+ return u.reshape((-1,1))*K*v.reshape((1,-1))
+
def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=20, log=False,**kwargs):
"""