diff options
-rw-r--r-- | ot/bregman.py | 13 |
1 files changed, 7 insertions, 6 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index 5327dbc..748ac30 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -919,7 +919,7 @@ def barycenter(A, M, reg, weights=None, numItermax=1000, return geometricBar(weights, UKv) -def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, verbose=False, log=False): +def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False): """Compute the entropic regularized wasserstein barycenter of distributions A where A is a collection of 2D images. @@ -948,6 +948,8 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1 Max number of iterations stopThr : float, optional Stop threshol on error (>0) + stabThr : float, optional + Stabilization threshold to avoid numerical precision issue verbose : bool, optional Print information along iterations log : bool, optional @@ -983,7 +985,6 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1 b = np.zeros_like(A[0, :, :]) U = np.ones_like(A) KV = np.ones_like(A) - threshold = 1e-30 # in order to avoids numerical precision issues cpt = 0 err = 1 @@ -993,7 +994,7 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1 [Y, X] = np.meshgrid(t, t) xi1 = np.exp(-(X - Y)**2 / reg) - def K(x): + def K(x): return np.dot(np.dot(xi1, x), xi1) while (err > stopThr and cpt < numItermax): @@ -1003,11 +1004,11 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1 b = np.zeros_like(A[0, :, :]) for r in range(A.shape[0]): - KV[r, :, :] = K(A[r, :, :] / np.maximum(threshold, K(U[r, :, :]))) - b += weights[r] * np.log(np.maximum(threshold, U[r, :, :] * KV[r, :, :])) + KV[r, :, :] = K(A[r, :, :] / np.maximum(stabThr, K(U[r, :, :]))) + b += weights[r] * np.log(np.maximum(stabThr, U[r, :, :] * KV[r, :, :])) b = np.exp(b) for r in range(A.shape[0]): - U[r, :, :] = b / np.maximum(threshold, KV[r, :, :]) + U[r, :, :] = b / np.maximum(stabThr, KV[r, :, :]) if cpt % 10 == 1: err = np.sum(np.abs(bold - b)) |