diff options
author | Alexandre Gramfort <alexandre.gramfort@m4x.org> | 2017-07-12 23:52:57 +0200 |
---|---|---|
committer | Alexandre Gramfort <alexandre.gramfort@m4x.org> | 2017-07-20 14:08:02 +0200 |
commit | 37ca3142a4c8c808382f5eb1c23bf198c3e4610e (patch) | |
tree | 55c2b56581b8f0f607d8e0a50778267baa4960e3 /ot/bregman.py | |
parent | da21b9888e77f7512727a4f50c60bd475e2c9606 (diff) |
pep8
Diffstat (limited to 'ot/bregman.py')
-rw-r--r-- | ot/bregman.py | 495 |
1 files changed, 257 insertions, 238 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index 0d68602..fe10880 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -5,7 +5,8 @@ Bregman projections for regularized OT import numpy as np -def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs): + +def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): u""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -92,22 +93,28 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver """ - if method.lower()=='sinkhorn': - sink= lambda: sinkhorn_knopp(a,b, M, reg,numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log,**kwargs) - elif method.lower()=='sinkhorn_stabilized': - sink= lambda: sinkhorn_stabilized(a,b, M, reg,numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, **kwargs) - elif method.lower()=='sinkhorn_epsilon_scaling': - sink= lambda: sinkhorn_epsilon_scaling(a,b, M, reg,numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, **kwargs) + if method.lower() == 'sinkhorn': + def sink(): + return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, **kwargs) + elif method.lower() == 'sinkhorn_stabilized': + def sink(): + return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, **kwargs) + elif method.lower() == 'sinkhorn_epsilon_scaling': + def sink(): + return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, **kwargs) else: print('Warning : unknown method using classic Sinkhorn Knopp') - sink= lambda: sinkhorn_knopp(a,b, M, reg, **kwargs) + + def sink(): + return sinkhorn_knopp(a, b, M, reg, **kwargs) return sink() -def sinkhorn2(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs): + +def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): u""" Solve the entropic regularization optimal transport problem and return the loss @@ -170,7 +177,7 @@ def sinkhorn2(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ve >>> M=[[0.,1.],[1.,0.]] >>> ot.sinkhorn2(a,b,M,1) array([ 0.26894142]) - + References @@ -194,27 +201,32 @@ def sinkhorn2(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ve """ - if method.lower()=='sinkhorn': - sink= lambda: sinkhorn_knopp(a,b, M, reg,numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log,**kwargs) - elif method.lower()=='sinkhorn_stabilized': - sink= lambda: sinkhorn_stabilized(a,b, M, reg,numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, **kwargs) - elif method.lower()=='sinkhorn_epsilon_scaling': - sink= lambda: sinkhorn_epsilon_scaling(a,b, M, reg,numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, **kwargs) + if method.lower() == 'sinkhorn': + def sink(): + return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, **kwargs) + elif method.lower() == 'sinkhorn_stabilized': + def sink(): + return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, **kwargs) + elif method.lower() == 'sinkhorn_epsilon_scaling': + def sink(): + return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, **kwargs) else: print('Warning : unknown method using classic Sinkhorn Knopp') - sink= lambda: sinkhorn_knopp(a,b, M, reg, **kwargs) - - b=np.asarray(b,dtype=np.float64) - if len(b.shape)<2: - b=b.reshape((-1,1)) + + def sink(): + return sinkhorn_knopp(a, b, M, reg, **kwargs) + + b = np.asarray(b, dtype=np.float64) + if len(b.shape) < 2: + b = b.reshape((-1, 1)) return sink() -def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs): +def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): """ Solve the entropic regularization optimal transport problem and return the OT matrix @@ -290,100 +302,101 @@ def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, """ - a=np.asarray(a,dtype=np.float64) - b=np.asarray(b,dtype=np.float64) - M=np.asarray(M,dtype=np.float64) - - - if len(a)==0: - a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0] - if len(b)==0: - b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1] + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64) + if len(a) == 0: + a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + if len(b) == 0: + b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] # init data Nini = len(a) Nfin = len(b) - if len(b.shape)>1: - nbb=b.shape[1] + if len(b.shape) > 1: + nbb = b.shape[1] else: - nbb=0 - + nbb = 0 if log: - log={'err':[]} + log = {'err': []} - # we assume that no distances are null except those of the diagonal of distances + # we assume that no distances are null except those of the diagonal of + # distances if nbb: - u = np.ones((Nini,nbb))/Nini - v = np.ones((Nfin,nbb))/Nfin + u = np.ones((Nini, nbb)) / Nini + v = np.ones((Nfin, nbb)) / Nfin else: - u = np.ones(Nini)/Nini - v = np.ones(Nfin)/Nfin + u = np.ones(Nini) / Nini + v = np.ones(Nfin) / Nfin + # print(reg) - #print(reg) + K = np.exp(-M / reg) + # print(np.min(K)) - K = np.exp(-M/reg) - #print(np.min(K)) - - Kp = (1/a).reshape(-1, 1) * K + Kp = (1 / a).reshape(-1, 1) * K cpt = 0 - err=1 - while (err>stopThr and cpt<numItermax): + err = 1 + while (err > stopThr and cpt < numItermax): uprev = u vprev = v KtransposeU = np.dot(K.T, u) v = np.divide(b, KtransposeU) - u = 1./np.dot(Kp,v) + u = 1. / np.dot(Kp, v) - if (np.any(KtransposeU==0) or - np.any(np.isnan(u)) or np.any(np.isnan(v)) or - np.any(np.isinf(u)) or np.any(np.isinf(v))): + if (np.any(KtransposeU == 0) or + np.any(np.isnan(u)) or np.any(np.isnan(v)) or + np.any(np.isinf(u)) or np.any(np.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop print('Warning: numerical errors at iteration', cpt) u = uprev v = vprev break - if cpt%10==0: - # we can speed up the process by checking for the error only all the 10th iterations + if cpt % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations if nbb: - err = np.sum((u-uprev)**2)/np.sum((u)**2)+np.sum((v-vprev)**2)/np.sum((v)**2) + err = np.sum((u - uprev)**2) / np.sum((u)**2) + \ + np.sum((v - vprev)**2) / np.sum((v)**2) else: transp = u.reshape(-1, 1) * (K * v) - err = np.linalg.norm((np.sum(transp,axis=0)-b))**2 + err = np.linalg.norm((np.sum(transp, axis=0) - b))**2 if log: log['err'].append(err) if verbose: - if cpt%200 ==0: - print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19) - print('{:5d}|{:8e}|'.format(cpt,err)) - cpt = cpt +1 + if cpt % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + cpt = cpt + 1 if log: - log['u']=u - log['v']=v + log['u'] = u + log['v'] = v - if nbb: #return only loss - res=np.zeros((nbb)) + if nbb: # return only loss + res = np.zeros((nbb)) for i in range(nbb): - res[i]=np.sum(u[:,i].reshape((-1,1))*K*v[:,i].reshape((1,-1))*M) + res[i] = np.sum( + u[:, i].reshape((-1, 1)) * K * v[:, i].reshape((1, -1)) * M) if log: - return res,log + return res, log else: return res - else: # return OT matrix + else: # return OT matrix if log: - return u.reshape((-1,1))*K*v.reshape((1,-1)),log + return u.reshape((-1, 1)) * K * v.reshape((1, -1)), log else: - return u.reshape((-1,1))*K*v.reshape((1,-1)) + return u.reshape((-1, 1)) * K * v.reshape((1, -1)) -def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=20, log=False,**kwargs): +def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=20, log=False, **kwargs): """ Solve the entropic regularization OT problem with log stabilization @@ -468,21 +481,21 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war """ - a=np.asarray(a,dtype=np.float64) - b=np.asarray(b,dtype=np.float64) - M=np.asarray(M,dtype=np.float64) + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64) - if len(a)==0: - a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0] - if len(b)==0: - b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1] + if len(a) == 0: + a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + if len(b) == 0: + b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] # test if multiple target - if len(b.shape)>1: - nbb=b.shape[1] - a=a[:,np.newaxis] + if len(b.shape) > 1: + nbb = b.shape[1] + a = a[:, np.newaxis] else: - nbb=0 + nbb = 0 # init data na = len(a) @@ -490,81 +503,80 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war cpt = 0 if log: - log={'err':[]} + log = {'err': []} - # we assume that no distances are null except those of the diagonal of distances + # we assume that no distances are null except those of the diagonal of + # distances if warmstart is None: - alpha,beta=np.zeros(na),np.zeros(nb) + alpha, beta = np.zeros(na), np.zeros(nb) else: - alpha,beta=warmstart + alpha, beta = warmstart if nbb: - u,v = np.ones((na,nbb))/na,np.ones((nb,nbb))/nb + u, v = np.ones((na, nbb)) / na, np.ones((nb, nbb)) / nb else: - u,v = np.ones(na)/na,np.ones(nb)/nb + u, v = np.ones(na) / na, np.ones(nb) / nb - def get_K(alpha,beta): + def get_K(alpha, beta): """log space computation""" - return np.exp(-(M-alpha.reshape((na,1))-beta.reshape((1,nb)))/reg) + return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg) - def get_Gamma(alpha,beta,u,v): + def get_Gamma(alpha, beta, u, v): """log space gamma computation""" - return np.exp(-(M-alpha.reshape((na,1))-beta.reshape((1,nb)))/reg+np.log(u.reshape((na,1)))+np.log(v.reshape((1,nb)))) + return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg + np.log(u.reshape((na, 1))) + np.log(v.reshape((1, nb)))) - #print(np.min(K)) + # print(np.min(K)) - K=get_K(alpha,beta) + K = get_K(alpha, beta) transp = K - loop=1 + loop = 1 cpt = 0 - err=1 + err = 1 while loop: - - uprev = u vprev = v # sinkhorn update - v = b/(np.dot(K.T,u)+1e-16) - u = a/(np.dot(K,v)+1e-16) - + v = b / (np.dot(K.T, u) + 1e-16) + u = a / (np.dot(K, v) + 1e-16) # remove numerical problems and store them in K - if np.abs(u).max()>tau or np.abs(v).max()>tau: + if np.abs(u).max() > tau or np.abs(v).max() > tau: if nbb: - alpha,beta=alpha+reg*np.max(np.log(u),1),beta+reg*np.max(np.log(v)) + alpha, beta = alpha + reg * \ + np.max(np.log(u), 1), beta + reg * np.max(np.log(v)) else: - alpha,beta=alpha+reg*np.log(u),beta+reg*np.log(v) + alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v) if nbb: - u,v = np.ones((na,nbb))/na,np.ones((nb,nbb))/nb + u, v = np.ones((na, nbb)) / na, np.ones((nb, nbb)) / nb else: - u,v = np.ones(na)/na,np.ones(nb)/nb - K=get_K(alpha,beta) - + u, v = np.ones(na) / na, np.ones(nb) / nb + K = get_K(alpha, beta) - if cpt%print_period==0: - # we can speed up the process by checking for the error only all the 10th iterations + if cpt % print_period == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations if nbb: - err = np.sum((u-uprev)**2)/np.sum((u)**2)+np.sum((v-vprev)**2)/np.sum((v)**2) + err = np.sum((u - uprev)**2) / np.sum((u)**2) + \ + np.sum((v - vprev)**2) / np.sum((v)**2) else: - transp = get_Gamma(alpha,beta,u,v) - err = np.linalg.norm((np.sum(transp,axis=0)-b))**2 + transp = get_Gamma(alpha, beta, u, v) + err = np.linalg.norm((np.sum(transp, axis=0) - b))**2 if log: log['err'].append(err) if verbose: - if cpt%(print_period*20) ==0: - print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19) - print('{:5d}|{:8e}|'.format(cpt,err)) + if cpt % (print_period * 20) == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + if err <= stopThr: + loop = False - if err<=stopThr: - loop=False - - if cpt>=numItermax: - loop=False - + if cpt >= numItermax: + loop = False if np.any(np.isnan(u)) or np.any(np.isnan(v)): # we have reached the machine precision @@ -574,34 +586,34 @@ def sinkhorn_stabilized(a,b, M, reg, numItermax = 1000,tau=1e3, stopThr=1e-9,war v = vprev break - cpt = cpt +1 - + cpt = cpt + 1 #print('err=',err,' cpt=',cpt) if log: - log['logu']=alpha/reg+np.log(u) - log['logv']=beta/reg+np.log(v) - log['alpha']=alpha+reg*np.log(u) - log['beta']=beta+reg*np.log(v) - log['warmstart']=(log['alpha'],log['beta']) + log['logu'] = alpha / reg + np.log(u) + log['logv'] = beta / reg + np.log(v) + log['alpha'] = alpha + reg * np.log(u) + log['beta'] = beta + reg * np.log(v) + log['warmstart'] = (log['alpha'], log['beta']) if nbb: - res=np.zeros((nbb)) + res = np.zeros((nbb)) for i in range(nbb): - res[i]=np.sum(get_Gamma(alpha,beta,u[:,i],v[:,i])*M) - return res,log + res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) + return res, log else: - return get_Gamma(alpha,beta,u,v),log + return get_Gamma(alpha, beta, u, v), log else: if nbb: - res=np.zeros((nbb)) + res = np.zeros((nbb)) for i in range(nbb): - res[i]=np.sum(get_Gamma(alpha,beta,u[:,i],v[:,i])*M) + res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) return res else: - return get_Gamma(alpha,beta,u,v) + return get_Gamma(alpha, beta, u, v) + -def sinkhorn_epsilon_scaling(a,b, M, reg, numItermax = 100, epsilon0=1e4, numInnerItermax = 100,tau=1e3, stopThr=1e-9,warmstart=None, verbose=False,print_period=10, log=False,**kwargs): +def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, log=False, **kwargs): """ Solve the entropic regularization optimal transport problem with log stabilization and epsilon scaling. @@ -690,14 +702,14 @@ def sinkhorn_epsilon_scaling(a,b, M, reg, numItermax = 100, epsilon0=1e4, numInn """ - a=np.asarray(a,dtype=np.float64) - b=np.asarray(b,dtype=np.float64) - M=np.asarray(M,dtype=np.float64) + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64) - if len(a)==0: - a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0] - if len(b)==0: - b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1] + if len(a) == 0: + a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + if len(b) == 0: + b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] # init data na = len(a) @@ -705,88 +717,94 @@ def sinkhorn_epsilon_scaling(a,b, M, reg, numItermax = 100, epsilon0=1e4, numInn # nrelative umerical precision with 64 bits numItermin = 35 - numItermax=max(numItermin,numItermax) # ensure that last velue is exact - + numItermax = max(numItermin, numItermax) # ensure that last velue is exact cpt = 0 if log: - log={'err':[]} + log = {'err': []} - # we assume that no distances are null except those of the diagonal of distances + # we assume that no distances are null except those of the diagonal of + # distances if warmstart is None: - alpha,beta=np.zeros(na),np.zeros(nb) + alpha, beta = np.zeros(na), np.zeros(nb) else: - alpha,beta=warmstart - + alpha, beta = warmstart - def get_K(alpha,beta): + def get_K(alpha, beta): """log space computation""" - return np.exp(-(M-alpha.reshape((na,1))-beta.reshape((1,nb)))/reg) + return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg) - #print(np.min(K)) - def get_reg(n): # exponential decreasing - return (epsilon0-reg)*np.exp(-n)+reg + # print(np.min(K)) + def get_reg(n): # exponential decreasing + return (epsilon0 - reg) * np.exp(-n) + reg - loop=1 + loop = 1 cpt = 0 - err=1 + err = 1 while loop: - regi=get_reg(cpt) + regi = get_reg(cpt) - G,logi=sinkhorn_stabilized(a,b, M, regi, numItermax = numInnerItermax, stopThr=1e-9,warmstart=(alpha,beta), verbose=False,print_period=20,tau=tau, log=True) + G, logi = sinkhorn_stabilized(a, b, M, regi, numItermax=numInnerItermax, stopThr=1e-9, warmstart=( + alpha, beta), verbose=False, print_period=20, tau=tau, log=True) - alpha=logi['alpha'] - beta=logi['beta'] + alpha = logi['alpha'] + beta = logi['beta'] - if cpt>=numItermax: - loop=False + if cpt >= numItermax: + loop = False - if cpt%(print_period)==0: # spsion nearly converged - # we can speed up the process by checking for the error only all the 10th iterations + if cpt % (print_period) == 0: # spsion nearly converged + # we can speed up the process by checking for the error only all + # the 10th iterations transp = G - err = np.linalg.norm((np.sum(transp,axis=0)-b))**2+np.linalg.norm((np.sum(transp,axis=1)-a))**2 + err = np.linalg.norm( + (np.sum(transp, axis=0) - b))**2 + np.linalg.norm((np.sum(transp, axis=1) - a))**2 if log: log['err'].append(err) if verbose: - if cpt%(print_period*10) ==0: - print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19) - print('{:5d}|{:8e}|'.format(cpt,err)) + if cpt % (print_period * 10) == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) - if err<=stopThr and cpt>numItermin: - loop=False + if err <= stopThr and cpt > numItermin: + loop = False - cpt = cpt +1 + cpt = cpt + 1 #print('err=',err,' cpt=',cpt) if log: - log['alpha']=alpha - log['beta']=beta - log['warmstart']=(log['alpha'],log['beta']) - return G,log + log['alpha'] = alpha + log['beta'] = beta + log['warmstart'] = (log['alpha'], log['beta']) + return G, log else: return G -def geometricBar(weights,alldistribT): +def geometricBar(weights, alldistribT): """return the weighted geometric mean of distributions""" - assert(len(weights)==alldistribT.shape[1]) - return np.exp(np.dot(np.log(alldistribT),weights.T)) + assert(len(weights) == alldistribT.shape[1]) + return np.exp(np.dot(np.log(alldistribT), weights.T)) + def geometricMean(alldistribT): """return the geometric mean of distributions""" - return np.exp(np.mean(np.log(alldistribT),axis=1)) + return np.exp(np.mean(np.log(alldistribT), axis=1)) -def projR(gamma,p): + +def projR(gamma, p): """return the KL projection on the row constrints """ - return np.multiply(gamma.T,p/np.maximum(np.sum(gamma,axis=1),1e-10)).T + return np.multiply(gamma.T, p / np.maximum(np.sum(gamma, axis=1), 1e-10)).T + -def projC(gamma,q): +def projC(gamma, q): """return the KL projection on the column constrints """ - return np.multiply(gamma,q/np.maximum(np.sum(gamma,axis=0),1e-10)) + return np.multiply(gamma, q / np.maximum(np.sum(gamma, axis=0), 1e-10)) -def barycenter(A,M,reg, weights=None, numItermax = 1000, stopThr=1e-4,verbose=False,log=False): +def barycenter(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False): """Compute the entropic regularized wasserstein barycenter of distributions A The function solves the following optimization problem: @@ -837,49 +855,49 @@ def barycenter(A,M,reg, weights=None, numItermax = 1000, stopThr=1e-4,verbose=Fa """ - if weights is None: - weights=np.ones(A.shape[1])/A.shape[1] + weights = np.ones(A.shape[1]) / A.shape[1] else: - assert(len(weights)==A.shape[1]) + assert(len(weights) == A.shape[1]) if log: - log={'err':[]} + log = {'err': []} - #M = M/np.median(M) # suggested by G. Peyre - K = np.exp(-M/reg) + # M = M/np.median(M) # suggested by G. Peyre + K = np.exp(-M / reg) cpt = 0 - err=1 + err = 1 - UKv=np.dot(K,np.divide(A.T,np.sum(K,axis=0)).T) - u = (geometricMean(UKv)/UKv.T).T + UKv = np.dot(K, np.divide(A.T, np.sum(K, axis=0)).T) + u = (geometricMean(UKv) / UKv.T).T - while (err>stopThr and cpt<numItermax): - cpt = cpt +1 - UKv=u*np.dot(K,np.divide(A,np.dot(K,u))) - u = (u.T*geometricBar(weights,UKv)).T/UKv + while (err > stopThr and cpt < numItermax): + cpt = cpt + 1 + UKv = u * np.dot(K, np.divide(A, np.dot(K, u))) + u = (u.T * geometricBar(weights, UKv)).T / UKv - if cpt%10==1: - err=np.sum(np.std(UKv,axis=1)) + if cpt % 10 == 1: + err = np.sum(np.std(UKv, axis=1)) # log and verbose print if log: log['err'].append(err) if verbose: - if cpt%200 ==0: - print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19) - print('{:5d}|{:8e}|'.format(cpt,err)) + if cpt % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) if log: - log['niter']=cpt - return geometricBar(weights,UKv),log + log['niter'] = cpt + return geometricBar(weights, UKv), log else: - return geometricBar(weights,UKv) + return geometricBar(weights, UKv) -def unmix(a,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, stopThr=1e-3,verbose=False,log=False): +def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, stopThr=1e-3, verbose=False, log=False): """ Compute the unmixing of an observation with a given dictionary using Wasserstein distance @@ -943,43 +961,44 @@ def unmix(a,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, stopThr=1e-3,verbose=Fal """ #M = M/np.median(M) - K = np.exp(-M/reg) + K = np.exp(-M / reg) #M0 = M0/np.median(M0) - K0 = np.exp(-M0/reg0) + K0 = np.exp(-M0 / reg0) old = h0 - err=1 - cpt=0 + err = 1 + cpt = 0 #log = {'niter':0, 'all_err':[]} if log: - log={'err':[]} - - - while (err>stopThr and cpt<numItermax): - K = projC(K,a) - K0 = projC(K0,h0) - new = np.sum(K0,axis=1) - inv_new = np.dot(D,new) # we recombine the current selection from dictionnary - other = np.sum(K,axis=1) - delta = np.exp(alpha*np.log(other)+(1-alpha)*np.log(inv_new)) # geometric interpolation - K = projR(K,delta) - K0 = np.dot(np.diag(np.dot(D.T,delta/inv_new)),K0) - - err=np.linalg.norm(np.sum(K0,axis=1)-old) + log = {'err': []} + + while (err > stopThr and cpt < numItermax): + K = projC(K, a) + K0 = projC(K0, h0) + new = np.sum(K0, axis=1) + # we recombine the current selection from dictionnary + inv_new = np.dot(D, new) + other = np.sum(K, axis=1) + # geometric interpolation + delta = np.exp(alpha * np.log(other) + (1 - alpha) * np.log(inv_new)) + K = projR(K, delta) + K0 = np.dot(np.diag(np.dot(D.T, delta / inv_new)), K0) + + err = np.linalg.norm(np.sum(K0, axis=1) - old) old = new if log: log['err'].append(err) if verbose: - if cpt%200 ==0: - print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19) - print('{:5d}|{:8e}|'.format(cpt,err)) + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) - cpt = cpt+1 + cpt = cpt + 1 if log: - log['niter']=cpt - return np.sum(K0,axis=1),log + log['niter'] = cpt + return np.sum(K0, axis=1), log else: - return np.sum(K0,axis=1) + return np.sum(K0, axis=1) |