summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py68
1 files changed, 35 insertions, 33 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 05f4d9d..f844f03 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -918,7 +918,8 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
else:
return geometricBar(weights, UKv)
-def convolutional_barycenter2d(A,reg,weights=None,numItermax = 10000, stopThr=1e-9, verbose=False, log=False):
+
+def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, verbose=False, log=False):
"""Compute the entropic regularized wasserstein barycenter of distributions A
where A is a collection of 2D images.
@@ -979,51 +980,52 @@ def convolutional_barycenter2d(A,reg,weights=None,numItermax = 10000, stopThr=1e
if log:
log = {'err': []}
- b=np.zeros_like(A[0,:,:])
- U=np.ones_like(A)
- KV=np.ones_like(A)
- threshold = 1e-30 # in order to avoids numerical precision issues
+ b = np.zeros_like(A[0, :, :])
+ U = np.ones_like(A)
+ KV = np.ones_like(A)
+ threshold = 1e-30 # in order to avoids numerical precision issues
cpt = 0
- err=1
-
- # build the convolution operator
- t = np.linspace(0,1,A.shape[1])
- [Y,X] = np.meshgrid(t,t)
- xi1 = np.exp(-(X-Y)**2/reg)
- K = lambda x: np.dot(np.dot(xi1,x),xi1)
-
- while (err>stopThr and cpt<numItermax):
-
- bold=b
- cpt = cpt +1
-
- b=np.zeros_like(A[0,:,:])
+ err = 1
+
+ # build the convolution operator
+ t = np.linspace(0, 1, A.shape[1])
+ [Y, X] = np.meshgrid(t, t)
+ xi1 = np.exp(-(X - Y)**2 / reg)
+
+ def K(x): return np.dot(np.dot(xi1, x), xi1)
+
+ while (err > stopThr and cpt < numItermax):
+
+ bold = b
+ cpt = cpt + 1
+
+ b = np.zeros_like(A[0, :, :])
for r in range(A.shape[0]):
- KV[r,:,:]=K(A[r,:,:]/np.maximum(threshold,K(U[r,:,:])))
- b += weights[r] * np.log(np.maximum(threshold, U[r,:,:]*KV[r,:,:]))
+ KV[r, :, :] = K(A[r, :, :] / np.maximum(threshold, K(U[r, :, :])))
+ b += weights[r] * np.log(np.maximum(threshold, U[r, :, :] * KV[r, :, :]))
b = np.exp(b)
for r in range(A.shape[0]):
- U[r,:,:]=b/np.maximum(threshold,KV[r,:,:])
-
- if cpt%10==1:
- err=np.sum(np.abs(bold-b))
+ U[r, :, :] = b / np.maximum(threshold, KV[r, :, :])
+
+ if cpt % 10 == 1:
+ err = np.sum(np.abs(bold - b))
# log and verbose print
if log:
log['err'].append(err)
if verbose:
- if cpt%200 ==0:
- print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19)
- print('{:5d}|{:8e}|'.format(cpt,err))
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
if log:
- log['niter']=cpt
- log['U']=U
- return b,log
+ log['niter'] = cpt
+ log['U'] = U
+ return b, log
else:
- return b
-
+ return b
+
def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
stopThr=1e-3, verbose=False, log=False):