summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorNicolas Courty <ncourty@irisa.fr>2018-09-07 14:41:00 +0200
committerNicolas Courty <ncourty@irisa.fr>2018-09-07 14:41:00 +0200
commitd19295b9cb29d21e09eeb28ac4b0e61990727023 (patch)
tree43c8e7840a7b432c1fab50451cdaf1e44d22578e /ot
parentdd200d544393cb6538f789141758bdbe8251bffc (diff)
stabThr and pep8
Diffstat (limited to 'ot')
-rw-r--r--ot/bregman.py13
1 files changed, 7 insertions, 6 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 5327dbc..748ac30 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -919,7 +919,7 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
return geometricBar(weights, UKv)
-def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, verbose=False, log=False):
+def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False):
"""Compute the entropic regularized wasserstein barycenter of distributions A
where A is a collection of 2D images.
@@ -948,6 +948,8 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
Max number of iterations
stopThr : float, optional
Stop threshol on error (>0)
+ stabThr : float, optional
+ Stabilization threshold to avoid numerical precision issue
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -983,7 +985,6 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
b = np.zeros_like(A[0, :, :])
U = np.ones_like(A)
KV = np.ones_like(A)
- threshold = 1e-30 # in order to avoids numerical precision issues
cpt = 0
err = 1
@@ -993,7 +994,7 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
[Y, X] = np.meshgrid(t, t)
xi1 = np.exp(-(X - Y)**2 / reg)
- def K(x):
+ def K(x):
return np.dot(np.dot(xi1, x), xi1)
while (err > stopThr and cpt < numItermax):
@@ -1003,11 +1004,11 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
b = np.zeros_like(A[0, :, :])
for r in range(A.shape[0]):
- KV[r, :, :] = K(A[r, :, :] / np.maximum(threshold, K(U[r, :, :])))
- b += weights[r] * np.log(np.maximum(threshold, U[r, :, :] * KV[r, :, :]))
+ KV[r, :, :] = K(A[r, :, :] / np.maximum(stabThr, K(U[r, :, :])))
+ b += weights[r] * np.log(np.maximum(stabThr, U[r, :, :] * KV[r, :, :]))
b = np.exp(b)
for r in range(A.shape[0]):
- U[r, :, :] = b / np.maximum(threshold, KV[r, :, :])
+ U[r, :, :] = b / np.maximum(stabThr, KV[r, :, :])
if cpt % 10 == 1:
err = np.sum(np.abs(bold - b))