summary | refs | log | tree | commit | diff
path: root/ot/bregman.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/bregman.py')
-rw-r--r-- ot/bregman.py 512
1 files changed, 269 insertions, 243 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 0d68602..d63c51d 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -3,9 +3,15 @@
Bregman projections for regularized OT
"""
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
import numpy as np
-def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):
+
+def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs):
u"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -92,22 +98,29 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver
"""
- if method.lower()=='sinkhorn':
- sink= lambda: sinkhorn_knopp(a,b, M, reg,numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,**kwargs)
- elif method.lower()=='sinkhorn_stabilized':
- sink= lambda: sinkhorn_stabilized(a,b, M, reg,numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
- elif method.lower()=='sinkhorn_epsilon_scaling':
- sink= lambda: sinkhorn_epsilon_scaling(a,b, M, reg,numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ if method.lower() == 'sinkhorn':
+ def sink():
+ return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ elif method.lower() == 'sinkhorn_stabilized':
+ def sink():
+ return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ elif method.lower() == 'sinkhorn_epsilon_scaling':
+ def sink():
+ return sinkhorn_epsilon_scaling(
+ a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log, **kwargs)
else:
print('Warning : unknown method using classic Sinkhorn Knopp')
- sink= lambda: sinkhorn_knopp(a,b, M, reg, **kwargs)
+
+ def sink():
+ return sinkhorn_knopp(a, b, M, reg, **kwargs)
return sink()
-def sinkhorn2(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):
+
+def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs):
u"""
Solve the entropic regularization optimal transport problem and return the loss
@@ -170,7 +183,7 @@ def sinkhorn2(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ve
>>> M=[[0.,1.],[1.,0.]]
>>> ot.sinkhorn2(a,b,M,1)
array([ 0.26894142])
-
+
References
@@ -194,27 +207,33 @@ def sinkhorn2(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ve
"""
- if method.lower()=='sinkhorn':
- sink= lambda: sinkhorn_knopp(a,b, M, reg,numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,**kwargs)
- elif method.lower()=='sinkhorn_stabilized':
- sink= lambda: sinkhorn_stabilized(a,b, M, reg,numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
- elif method.lower()=='sinkhorn_epsilon_scaling':
- sink= lambda: sinkhorn_epsilon_scaling(a,b, M, reg,numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ if method.lower() == 'sinkhorn':
+ def sink():
+ return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ elif method.lower() == 'sinkhorn_stabilized':
+ def sink():
+ return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ elif method.lower() == 'sinkhorn_epsilon_scaling':
+ def sink():
+ return sinkhorn_epsilon_scaling(
+ a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log, **kwargs)
else:
print('Warning : unknown method using classic Sinkhorn Knopp')
- sink= lambda: sinkhorn_knopp(a,b, M, reg, **kwargs)
-
- b=np.asarray(b,dtype=np.float64)
- if len(b.shape)<2:
- b=b.reshape((-1,1))
+
+ def sink():
+ return sinkhorn_knopp(a, b, M, reg, **kwargs)
+
+ b = np.asarray(b, dtype=np.float64)
+ if len(b.shape) < 2:
+ b = b.reshape((-1, 1))
return sink()
-def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):
+def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs):
"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -290,100 +309,101 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False,
"""
- a=np.asarray(a,dtype=np.float64)
- b=np.asarray(b,dtype=np.float64)
- M=np.asarray(M,dtype=np.float64)
-
-
- if len(a)==0:
- a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0]
- if len(b)==0:
- b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1]
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64)
+ if len(a) == 0:
+ a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ if len(b) == 0:
+ b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
# init data
Nini = len(a)
Nfin = len(b)
- if len(b.shape)>1:
- nbb=b.shape[1]
+ if len(b.shape) > 1:
+ nbb = b.shape[1]
else:
- nbb=0
-
+ nbb = 0
if log:
- log={'err':[]}
+ log = {'err': []}
- # we assume that no distances are null except those of the diagonal of distances
+ # we assume that no distances are null except those of the diagonal of
+ # distances
if nbb:
- u = np.ones((Nini,nbb))/Nini
- v = np.ones((Nfin,nbb))/Nfin
+ u = np.ones((Nini, nbb)) / Nini
+ v = np.ones((Nfin, nbb)) / Nfin
else:
- u = np.ones(Nini)/Nini
- v = np.ones(Nfin)/Nfin
+ u = np.ones(Nini) / Nini
+ v = np.ones(Nfin) / Nfin
+ # print(reg)
- #print(reg)
+ K = np.exp(-M / reg)
+ # print(np.min(K))
- K = np.exp(-M/reg)
- #print(np.min(K))
-
- Kp = (1/a).reshape(-1, 1) * K
+ Kp = (1 / a).reshape(-1, 1) * K
cpt = 0
- err=1
- while (err>stopThr and cpt<numItermax):
+ err = 1
+ while (err > stopThr and cpt < numItermax):
uprev = u
vprev = v
KtransposeU = np.dot(K.T, u)
v = np.divide(b, KtransposeU)
- u = 1./np.dot(Kp,v)
+ u = 1. / np.dot(Kp, v)
- if (np.any(KtransposeU==0) or
- np.any(np.isnan(u)) or np.any(np.isnan(v)) or
- np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ if (np.any(KtransposeU == 0) or
+ np.any(np.isnan(u)) or np.any(np.isnan(v)) or
+ np.any(np.isinf(u)) or np.any(np.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
print('Warning: numerical errors at iteration', cpt)
u = uprev
v = vprev
break
- if cpt%10==0:
- # we can speed up the process by checking for the error only all the 10th iterations
+ if cpt % 10 == 0:
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
if nbb:
- err = np.sum((u-uprev)**2)/np.sum((u)**2)+np.sum((v-vprev)**2)/np.sum((v)**2)
+ err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
+ np.sum((v - vprev)**2) / np.sum((v)**2)
else:
transp = u.reshape(-1, 1) * (K * v)
- err = np.linalg.norm((np.sum(transp,axis=0)-b))**2
+ err = np.linalg.norm((np.sum(transp, axis=0) - b))**2
if log:
log['err'].append(err)
if verbose:
- if cpt%200 ==0:
- print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19)
- print('{:5d}|{:8e}|'.format(cpt,err))
- cpt = cpt +1
+ if cpt % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+ cpt = cpt + 1
if log:
- log['u']=u
- log['v']=v
+ log['u'] = u
+ log['v'] = v
- if nbb: #return only loss
- res=np.zeros((nbb))
+ if nbb: # return only loss
+ res = np.zeros((nbb))
for i in range(nbb):
- res[i]=np.sum(u[:,i].reshape((-1,1))*K*v[:,i].reshape((1,-1))*M)
+ res[i] = np.sum(
+ u[:, i].reshape((-1, 1)) * K * v[:, i].reshape((1, -1)) * M)
if log:
- return res,log
+ return res, log
else:
return res
- else: # return OT matrix
+ else: # return OT matrix
if log:
- return u.reshape((-1,1))*K*v.reshape((1,-1)),log
+ return u.reshape((-1, 1)) * K * v.reshape((1, -1)), log
else:
- return u.reshape((-1,1))*K*v.reshape((1,-1))
+ return u.reshape((-1, 1)) * K * v.reshape((1, -1))
-def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=20, log=False,**kwargs):
+def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=20, log=False, **kwargs):
"""
Solve the entropic regularization OT problem with log stabilization
@@ -468,21 +488,21 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
"""
- a=np.asarray(a,dtype=np.float64)
- b=np.asarray(b,dtype=np.float64)
- M=np.asarray(M,dtype=np.float64)
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64)
- if len(a)==0:
- a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0]
- if len(b)==0:
- b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1]
+ if len(a) == 0:
+ a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ if len(b) == 0:
+ b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
# test if multiple target
- if len(b.shape)>1:
- nbb=b.shape[1]
- a=a[:,np.newaxis]
+ if len(b.shape) > 1:
+ nbb = b.shape[1]
+ a = a[:, np.newaxis]
else:
- nbb=0
+ nbb = 0
# init data
na = len(a)
@@ -490,81 +510,80 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
cpt = 0
if log:
- log={'err':[]}
+ log = {'err': []}
- # we assume that no distances are null except those of the diagonal of distances
+ # we assume that no distances are null except those of the diagonal of
+ # distances
if warmstart is None:
- alpha,beta=np.zeros(na),np.zeros(nb)
+ alpha, beta = np.zeros(na), np.zeros(nb)
else:
- alpha,beta=warmstart
+ alpha, beta = warmstart
if nbb:
- u,v = np.ones((na,nbb))/na,np.ones((nb,nbb))/nb
+ u, v = np.ones((na, nbb)) / na, np.ones((nb, nbb)) / nb
else:
- u,v = np.ones(na)/na,np.ones(nb)/nb
+ u, v = np.ones(na) / na, np.ones(nb) / nb
- def get_K(alpha,beta):
+ def get_K(alpha, beta):
"""log space computation"""
- return np.exp(-(M-alpha.reshape((na,1))-beta.reshape((1,nb)))/reg)
+ return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg)
- def get_Gamma(alpha,beta,u,v):
+ def get_Gamma(alpha, beta, u, v):
"""log space gamma computation"""
- return np.exp(-(M-alpha.reshape((na,1))-beta.reshape((1,nb)))/reg+np.log(u.reshape((na,1)))+np.log(v.reshape((1,nb))))
+ return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg + np.log(u.reshape((na, 1))) + np.log(v.reshape((1, nb))))
- #print(np.min(K))
+ # print(np.min(K))
- K=get_K(alpha,beta)
+ K = get_K(alpha, beta)
transp = K
- loop=1
+ loop = 1
cpt = 0
- err=1
+ err = 1
while loop:
-
-
uprev = u
vprev = v
# sinkhorn update
- v = b/(np.dot(K.T,u)+1e-16)
- u = a/(np.dot(K,v)+1e-16)
-
+ v = b / (np.dot(K.T, u) + 1e-16)
+ u = a / (np.dot(K, v) + 1e-16)
# remove numerical problems and store them in K
- if np.abs(u).max()>tau or np.abs(v).max()>tau:
+ if np.abs(u).max() > tau or np.abs(v).max() > tau:
if nbb:
- alpha,beta=alpha+reg*np.max(np.log(u),1),beta+reg*np.max(np.log(v))
+ alpha, beta = alpha + reg * \
+ np.max(np.log(u), 1), beta + reg * np.max(np.log(v))
else:
- alpha,beta=alpha+reg*np.log(u),beta+reg*np.log(v)
+ alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v)
if nbb:
- u,v = np.ones((na,nbb))/na,np.ones((nb,nbb))/nb
+ u, v = np.ones((na, nbb)) / na, np.ones((nb, nbb)) / nb
else:
- u,v = np.ones(na)/na,np.ones(nb)/nb
- K=get_K(alpha,beta)
-
+ u, v = np.ones(na) / na, np.ones(nb) / nb
+ K = get_K(alpha, beta)
- if cpt%print_period==0:
- # we can speed up the process by checking for the error only all the 10th iterations
+ if cpt % print_period == 0:
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
if nbb:
- err = np.sum((u-uprev)**2)/np.sum((u)**2)+np.sum((v-vprev)**2)/np.sum((v)**2)
+ err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
+ np.sum((v - vprev)**2) / np.sum((v)**2)
else:
- transp = get_Gamma(alpha,beta,u,v)
- err = np.linalg.norm((np.sum(transp,axis=0)-b))**2
+ transp = get_Gamma(alpha, beta, u, v)
+ err = np.linalg.norm((np.sum(transp, axis=0) - b))**2
if log:
log['err'].append(err)
if verbose:
- if cpt%(print_period*20) ==0:
- print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19)
- print('{:5d}|{:8e}|'.format(cpt,err))
+ if cpt % (print_period * 20) == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+ if err <= stopThr:
+ loop = False
- if err<=stopThr:
- loop=False
-
- if cpt>=numItermax:
- loop=False
-
+ if cpt >= numItermax:
+ loop = False
if np.any(np.isnan(u)) or np.any(np.isnan(v)):
# we have reached the machine precision
@@ -574,34 +593,34 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war
v = vprev
break
- cpt = cpt +1
+ cpt = cpt + 1
-
- #print('err=',err,' cpt=',cpt)
+ # print('err=',err,' cpt=',cpt)
if log:
- log['logu']=alpha/reg+np.log(u)
- log['logv']=beta/reg+np.log(v)
- log['alpha']=alpha+reg*np.log(u)
- log['beta']=beta+reg*np.log(v)
- log['warmstart']=(log['alpha'],log['beta'])
+ log['logu'] = alpha / reg + np.log(u)
+ log['logv'] = beta / reg + np.log(v)
+ log['alpha'] = alpha + reg * np.log(u)
+ log['beta'] = beta + reg * np.log(v)
+ log['warmstart'] = (log['alpha'], log['beta'])
if nbb:
- res=np.zeros((nbb))
+ res = np.zeros((nbb))
for i in range(nbb):
- res[i]=np.sum(get_Gamma(alpha,beta,u[:,i],v[:,i])*M)
- return res,log
+ res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
+ return res, log
else:
- return get_Gamma(alpha,beta,u,v),log
+ return get_Gamma(alpha, beta, u, v), log
else:
if nbb:
- res=np.zeros((nbb))
+ res = np.zeros((nbb))
for i in range(nbb):
- res[i]=np.sum(get_Gamma(alpha,beta,u[:,i],v[:,i])*M)
+ res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
return res
else:
- return get_Gamma(alpha,beta,u,v)
+ return get_Gamma(alpha, beta, u, v)
+
-def sinkhorn_epsilon_scaling(a,b, M, reg, numItermax = 100, epsilon0=1e4, numInnerItermax = 100,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=10, log=False,**kwargs):
+def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, log=False, **kwargs):
"""
Solve the entropic regularization optimal transport problem with log
stabilization and epsilon scaling.
@@ -690,14 +709,14 @@ def sinkhorn_epsilon_scaling(a,b, M, reg, numItermax = 100, epsilon0=1e4, numInn
"""
- a=np.asarray(a,dtype=np.float64)
- b=np.asarray(b,dtype=np.float64)
- M=np.asarray(M,dtype=np.float64)
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64)
- if len(a)==0:
- a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0]
- if len(b)==0:
- b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1]
+ if len(a) == 0:
+ a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ if len(b) == 0:
+ b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
# init data
na = len(a)
@@ -705,88 +724,94 @@ def sinkhorn_epsilon_scaling(a,b, M, reg, numItermax = 100, epsilon0=1e4, numInn
# nrelative umerical precision with 64 bits
numItermin = 35
- numItermax=max(numItermin,numItermax) # ensure that last velue is exact
-
+ numItermax = max(numItermin, numItermax) # ensure that last velue is exact
cpt = 0
if log:
- log={'err':[]}
+ log = {'err': []}
- # we assume that no distances are null except those of the diagonal of distances
+ # we assume that no distances are null except those of the diagonal of
+ # distances
if warmstart is None:
- alpha,beta=np.zeros(na),np.zeros(nb)
+ alpha, beta = np.zeros(na), np.zeros(nb)
else:
- alpha,beta=warmstart
-
+ alpha, beta = warmstart
- def get_K(alpha,beta):
+ def get_K(alpha, beta):
"""log space computation"""
- return np.exp(-(M-alpha.reshape((na,1))-beta.reshape((1,nb)))/reg)
+ return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg)
- #print(np.min(K))
- def get_reg(n): # exponential decreasing
- return (epsilon0-reg)*np.exp(-n)+reg
+ # print(np.min(K))
+ def get_reg(n): # exponential decreasing
+ return (epsilon0 - reg) * np.exp(-n) + reg
- loop=1
+ loop = 1
cpt = 0
- err=1
+ err = 1
while loop:
- regi=get_reg(cpt)
+ regi = get_reg(cpt)
- G,logi=sinkhorn_stabilized(a,b, M, regi, numItermax = numInnerItermax, stopThr=1e-9,warmstart=(alpha,beta), verbose=False,print_period=20,tau=tau, log=True)
+ G, logi = sinkhorn_stabilized(a, b, M, regi, numItermax=numInnerItermax, stopThr=1e-9, warmstart=(
+ alpha, beta), verbose=False, print_period=20, tau=tau, log=True)
- alpha=logi['alpha']
- beta=logi['beta']
+ alpha = logi['alpha']
+ beta = logi['beta']
- if cpt>=numItermax:
- loop=False
+ if cpt >= numItermax:
+ loop = False
- if cpt%(print_period)==0: # spsion nearly converged
- # we can speed up the process by checking for the error only all the 10th iterations
+ if cpt % (print_period) == 0: # spsion nearly converged
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
transp = G
- err = np.linalg.norm((np.sum(transp,axis=0)-b))**2+np.linalg.norm((np.sum(transp,axis=1)-a))**2
+ err = np.linalg.norm(
+ (np.sum(transp, axis=0) - b))**2 + np.linalg.norm((np.sum(transp, axis=1) - a))**2
if log:
log['err'].append(err)
if verbose:
- if cpt%(print_period*10) ==0:
- print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19)
- print('{:5d}|{:8e}|'.format(cpt,err))
+ if cpt % (print_period * 10) == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
- if err<=stopThr and cpt>numItermin:
- loop=False
+ if err <= stopThr and cpt > numItermin:
+ loop = False
- cpt = cpt +1
- #print('err=',err,' cpt=',cpt)
+ cpt = cpt + 1
+ # print('err=',err,' cpt=',cpt)
if log:
- log['alpha']=alpha
- log['beta']=beta
- log['warmstart']=(log['alpha'],log['beta'])
- return G,log
+ log['alpha'] = alpha
+ log['beta'] = beta
+ log['warmstart'] = (log['alpha'], log['beta'])
+ return G, log
else:
return G
-def geometricBar(weights,alldistribT):
+def geometricBar(weights, alldistribT):
"""return the weighted geometric mean of distributions"""
- assert(len(weights)==alldistribT.shape[1])
- return np.exp(np.dot(np.log(alldistribT),weights.T))
+ assert(len(weights) == alldistribT.shape[1])
+ return np.exp(np.dot(np.log(alldistribT), weights.T))
+
def geometricMean(alldistribT):
"""return the geometric mean of distributions"""
- return np.exp(np.mean(np.log(alldistribT),axis=1))
+ return np.exp(np.mean(np.log(alldistribT), axis=1))
-def projR(gamma,p):
+
+def projR(gamma, p):
"""return the KL projection on the row constrints """
- return np.multiply(gamma.T,p/np.maximum(np.sum(gamma,axis=1),1e-10)).T
+ return np.multiply(gamma.T, p / np.maximum(np.sum(gamma, axis=1), 1e-10)).T
+
-def projC(gamma,q):
+def projC(gamma, q):
"""return the KL projection on the column constrints """
- return np.multiply(gamma,q/np.maximum(np.sum(gamma,axis=0),1e-10))
+ return np.multiply(gamma, q / np.maximum(np.sum(gamma, axis=0), 1e-10))
-def barycenter(A,M,reg, weights=None, numItermax = 1000, stopThr=1e-4,verbose=False,log=False):
+def barycenter(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False):
"""Compute the entropic regularized wasserstein barycenter of distributions A
The function solves the following optimization problem:
@@ -837,49 +862,49 @@ def barycenter(A,M,reg, weights=None, numItermax = 1000, stopThr=1e-4,verbose=Fa
"""
-
if weights is None:
- weights=np.ones(A.shape[1])/A.shape[1]
+ weights = np.ones(A.shape[1]) / A.shape[1]
else:
- assert(len(weights)==A.shape[1])
+ assert(len(weights) == A.shape[1])
if log:
- log={'err':[]}
+ log = {'err': []}
- #M = M/np.median(M) # suggested by G. Peyre
- K = np.exp(-M/reg)
+ # M = M/np.median(M) # suggested by G. Peyre
+ K = np.exp(-M / reg)
cpt = 0
- err=1
+ err = 1
- UKv=np.dot(K,np.divide(A.T,np.sum(K,axis=0)).T)
- u = (geometricMean(UKv)/UKv.T).T
+ UKv = np.dot(K, np.divide(A.T, np.sum(K, axis=0)).T)
+ u = (geometricMean(UKv) / UKv.T).T
- while (err>stopThr and cpt<numItermax):
- cpt = cpt +1
- UKv=u*np.dot(K,np.divide(A,np.dot(K,u)))
- u = (u.T*geometricBar(weights,UKv)).T/UKv
+ while (err > stopThr and cpt < numItermax):
+ cpt = cpt + 1
+ UKv = u * np.dot(K, np.divide(A, np.dot(K, u)))
+ u = (u.T * geometricBar(weights, UKv)).T / UKv
- if cpt%10==1:
- err=np.sum(np.std(UKv,axis=1))
+ if cpt % 10 == 1:
+ err = np.sum(np.std(UKv, axis=1))
# log and verbose print
if log:
log['err'].append(err)
if verbose:
- if cpt%200 ==0:
- print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19)
- print('{:5d}|{:8e}|'.format(cpt,err))
+ if cpt % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
if log:
- log['niter']=cpt
- return geometricBar(weights,UKv),log
+ log['niter'] = cpt
+ return geometricBar(weights, UKv), log
else:
- return geometricBar(weights,UKv)
+ return geometricBar(weights, UKv)
-def unmix(a,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, stopThr=1e-3,verbose=False,log=False):
+def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, stopThr=1e-3, verbose=False, log=False):
"""
Compute the unmixing of an observation with a given dictionary using Wasserstein distance
@@ -942,44 +967,45 @@ def unmix(a,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, stopThr=1e-3,verbose=Fal
"""
- #M = M/np.median(M)
- K = np.exp(-M/reg)
+ # M = M/np.median(M)
+ K = np.exp(-M / reg)
- #M0 = M0/np.median(M0)
- K0 = np.exp(-M0/reg0)
+ # M0 = M0/np.median(M0)
+ K0 = np.exp(-M0 / reg0)
old = h0
- err=1
- cpt=0
- #log = {'niter':0, 'all_err':[]}
+ err = 1
+ cpt = 0
+ # log = {'niter':0, 'all_err':[]}
if log:
- log={'err':[]}
-
-
- while (err>stopThr and cpt<numItermax):
- K = projC(K,a)
- K0 = projC(K0,h0)
- new = np.sum(K0,axis=1)
- inv_new = np.dot(D,new) # we recombine the current selection from dictionnary
- other = np.sum(K,axis=1)
- delta = np.exp(alpha*np.log(other)+(1-alpha)*np.log(inv_new)) # geometric interpolation
- K = projR(K,delta)
- K0 = np.dot(np.diag(np.dot(D.T,delta/inv_new)),K0)
-
- err=np.linalg.norm(np.sum(K0,axis=1)-old)
+ log = {'err': []}
+
+ while (err > stopThr and cpt < numItermax):
+ K = projC(K, a)
+ K0 = projC(K0, h0)
+ new = np.sum(K0, axis=1)
+ # we recombine the current selection from dictionnary
+ inv_new = np.dot(D, new)
+ other = np.sum(K, axis=1)
+ # geometric interpolation
+ delta = np.exp(alpha * np.log(other) + (1 - alpha) * np.log(inv_new))
+ K = projR(K, delta)
+ K0 = np.dot(np.diag(np.dot(D.T, delta / inv_new)), K0)
+
+ err = np.linalg.norm(np.sum(K0, axis=1) - old)
old = new
if log:
log['err'].append(err)
if verbose:
- if cpt%200 ==0:
- print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19)
- print('{:5d}|{:8e}|'.format(cpt,err))
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
- cpt = cpt+1
+ cpt = cpt + 1
if log:
- log['niter']=cpt
- return np.sum(K0,axis=1),log
+ log['niter'] = cpt
+ return np.sum(K0, axis=1), log
else:
- return np.sum(K0,axis=1)
+ return np.sum(K0, axis=1)