diff options
Diffstat (limited to 'ot')
-rw-r--r-- | ot/bregman.py | 68 |
1 files changed, 35 insertions, 33 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index 05f4d9d..f844f03 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -918,7 +918,8 @@ def barycenter(A, M, reg, weights=None, numItermax=1000, else: return geometricBar(weights, UKv) -def convolutional_barycenter2d(A,reg,weights=None,numItermax = 10000, stopThr=1e-9, verbose=False, log=False): + +def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, verbose=False, log=False): """Compute the entropic regularized wasserstein barycenter of distributions A where A is a collection of 2D images. @@ -979,51 +980,52 @@ def convolutional_barycenter2d(A,reg,weights=None,numItermax = 10000, stopThr=1e if log: log = {'err': []} - b=np.zeros_like(A[0,:,:]) - U=np.ones_like(A) - KV=np.ones_like(A) - threshold = 1e-30 # in order to avoids numerical precision issues + b = np.zeros_like(A[0, :, :]) + U = np.ones_like(A) + KV = np.ones_like(A) + threshold = 1e-30 # in order to avoids numerical precision issues cpt = 0 - err=1 - - # build the convolution operator - t = np.linspace(0,1,A.shape[1]) - [Y,X] = np.meshgrid(t,t) - xi1 = np.exp(-(X-Y)**2/reg) - K = lambda x: np.dot(np.dot(xi1,x),xi1) - - while (err>stopThr and cpt<numItermax): - - bold=b - cpt = cpt +1 - - b=np.zeros_like(A[0,:,:]) + err = 1 + + # build the convolution operator + t = np.linspace(0, 1, A.shape[1]) + [Y, X] = np.meshgrid(t, t) + xi1 = np.exp(-(X - Y)**2 / reg) + + def K(x): return np.dot(np.dot(xi1, x), xi1) + + while (err > stopThr and cpt < numItermax): + + bold = b + cpt = cpt + 1 + + b = np.zeros_like(A[0, :, :]) for r in range(A.shape[0]): - KV[r,:,:]=K(A[r,:,:]/np.maximum(threshold,K(U[r,:,:]))) - b += weights[r] * np.log(np.maximum(threshold, U[r,:,:]*KV[r,:,:])) + KV[r, :, :] = K(A[r, :, :] / np.maximum(threshold, K(U[r, :, :]))) + b += weights[r] * np.log(np.maximum(threshold, U[r, :, :] * KV[r, :, :])) b = np.exp(b) for r in range(A.shape[0]): - U[r,:,:]=b/np.maximum(threshold,KV[r,:,:]) - - if cpt%10==1: - err=np.sum(np.abs(bold-b)) + U[r, :, :] = b / np.maximum(threshold, KV[r, :, :]) + + if cpt % 10 == 1: + err = np.sum(np.abs(bold - b)) # log and verbose print if log: log['err'].append(err) if verbose: - if cpt%200 ==0: - print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19) - print('{:5d}|{:8e}|'.format(cpt,err)) + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) if log: - log['niter']=cpt - log['U']=U - return b,log + log['niter'] = cpt + log['U'] = U + return b, log else: - return b - + return b + def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, stopThr=1e-3, verbose=False, log=False): |