summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/plot_OT_2D_samples.py4
-rw-r--r--examples/plot_compute_emd.py9
-rw-r--r--ot/bregman.py68
3 files changed, 58 insertions, 23 deletions
diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py
index 3b95083..edfb781 100644
--- a/examples/plot_OT_2D_samples.py
+++ b/examples/plot_OT_2D_samples.py
@@ -13,7 +13,7 @@ import ot
#%% parameters and data generation
-n=2 # nb samples
+n=50 # nb samples
mu_s=np.array([0,0])
cov_s=np.array([[1,0],[0,1]])
@@ -62,7 +62,7 @@ pl.title('OT matrix with samples')
#%% sinkhorn
# reg term
-lambd=5e-3
+lambd=5e-4
Gs=ot.sinkhorn(a,b,M,lambd)
diff --git a/examples/plot_compute_emd.py b/examples/plot_compute_emd.py
index 226bc97..08de6ee 100644
--- a/examples/plot_compute_emd.py
+++ b/examples/plot_compute_emd.py
@@ -16,7 +16,7 @@ from ot.datasets import get_1D_gauss as gauss
#%% parameters
n=100 # nb bins
-n_target=10 # nb target distributions
+n_target=50 # nb target distributions
# bin positions
@@ -61,14 +61,15 @@ pl.legend()
#%%
reg=1e-2
-d_sinkhorn=ot.sinkhorn(a,B,M,reg)
+d_sinkhorn=ot.sinkhorn(a,B,M,reg,method='sinkhorn_stabilized')
+d_sinkhorn0=ot.sinkhorn(a,B,M,reg)
d_sinkhorn2=ot.sinkhorn(a,B,M2,reg)
pl.figure(2)
pl.clf()
pl.plot(d_emd,label='Euclidean EMD')
pl.plot(d_emd2,label='Squared Euclidean EMD')
-pl.plot(d_sinkhorn,label='Euclidean Sinkhorn')
-pl.plot(d_emd2,label='Squared Euclidean Sinkhorn')
+pl.plot(d_sinkhorn,'+',label='Euclidean Sinkhorn')
+pl.plot(d_sinkhorn2,'+',label='Squared Euclidean Sinkhorn')
pl.title('EMD distances')
pl.legend() \ No newline at end of file
diff --git a/ot/bregman.py b/ot/bregman.py
index 3a9b15f..e847f24 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -32,8 +32,9 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver
----------
a : np.ndarray (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
- samples in the target domain
+ b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ samples in the target domain, compute sinkhorn with multiple targets
+ and fixed M if b is a matrix (return OT loss + dual variables in log)
M : np.ndarray (ns,nt)
loss matrix
reg : float
@@ -134,8 +135,9 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
----------
a : np.ndarray (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
- samples in the target domain
+ b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ samples in the target domain, compute sinkhorn with multiple targets
+ and fixed M if b is a matrix (return OT loss + dual variables in log)
M : np.ndarray (ns,nt)
loss matrix
reg : float
@@ -369,11 +371,17 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
if len(b)==0:
b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1]
+ # test if multiple target
+ if len(b.shape)>1:
+ nbb=b.shape[1]
+ a=a[:,np.newaxis]
+ else:
+ nbb=0
+
# init data
na = len(a)
nb = len(b)
-
cpt = 0
if log:
log={'err':[]}
@@ -383,10 +391,11 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
alpha,beta=np.zeros(na),np.zeros(nb)
else:
alpha,beta=warmstart
- u,v = np.ones(na)/na,np.ones(nb)/nb
-
- #print(reg)
-
+
+ if nbb:
+ u,v = np.ones((na,nbb))/na,np.ones((nb,nbb))/nb
+ else:
+ u,v = np.ones(na)/na,np.ones(nb)/nb
def get_K(alpha,beta):
"""log space computation"""
@@ -405,22 +414,32 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
err=1
while loop:
+ # remove numerical problems and store them in K
if np.abs(u).max()>tau or np.abs(v).max()>tau:
- alpha,beta=alpha+reg*np.log(u),beta+reg*np.log(v)
- u,v = np.ones(na)/na,np.ones(nb)/nb
+ if nbb:
+ alpha,beta=alpha+reg*np.max(np.log(u),1),beta+reg*np.max(np.log(v))
+ else:
+ alpha,beta=alpha+reg*np.log(u),beta+reg*np.log(v)
+ if nbb:
+ u,v = np.ones((na,nbb))/na,np.ones((nb,nbb))/nb
+ else:
+ u,v = np.ones(na)/na,np.ones(nb)/nb
K=get_K(alpha,beta)
uprev = u
vprev = v
+
+ # sinkhorn update
v = b/np.dot(K.T,u)
u = a/np.dot(K,v)
-
-
if cpt%print_period==0:
# we can speed up the process by checking for the error only all the 10th iterations
- transp = get_Gamma(alpha,beta,u,v)
- err = np.linalg.norm((np.sum(transp,axis=0)-b))**2
+ if nbb:
+ err = np.sum((u-uprev)**2)/np.sum((u)**2)+np.sum((v-vprev)**2)/np.sum((v)**2)
+ else:
+ transp = get_Gamma(alpha,beta,u,v)
+ err = np.linalg.norm((np.sum(transp,axis=0)-b))**2
if log:
log['err'].append(err)
@@ -448,6 +467,8 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
break
cpt = cpt +1
+
+
#print('err=',err,' cpt=',cpt)
if log:
log['logu']=alpha/reg+np.log(u)
@@ -455,9 +476,22 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
log['alpha']=alpha+reg*np.log(u)
log['beta']=beta+reg*np.log(v)
log['warmstart']=(log['alpha'],log['beta'])
- return get_Gamma(alpha,beta,u,v),log
+ if nbb:
+ res=np.zeros((nbb))
+ for i in range(nbb):
+ res[i]=np.sum(get_Gamma(alpha,beta,u[:,i],v[:,i])*M)
+ return res,log
+
+ else:
+ return get_Gamma(alpha,beta,u,v),log
else:
- return get_Gamma(alpha,beta,u,v)
+ if nbb:
+ res=np.zeros((nbb))
+ for i in range(nbb):
+ res[i]=np.sum(get_Gamma(alpha,beta,u[:,i],v[:,i])*M)
+ return res
+ else:
+ return get_Gamma(alpha,beta,u,v)
def sinkhorn_epsilon_scaling(a,b, M, reg, numItermax = 100, epsilon0=1e4, numInnerItermax = 100,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=10, log=False,**kwargs):
"""