summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
authorLaetitia Chapel <laetitia.chapel@univ-ubs.fr>2020-04-09 14:14:34 +0200
committerGitHub <noreply@github.com>2020-04-09 14:14:34 +0200
commitfff2463aafd58343c8bc2ed7875622e16a8c1cee (patch)
treeb23efef253c4cc42c13bf3f7aad671f27bf43a3d /ot/bregman.py
parent9f63ee92e281427ab3d520f75bb9c3406b547365 (diff)
parent4cd4e09f89fe6f95a07d632365612b797ab760da (diff)
Merge branch 'master' into partial-W-and-GW
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py8
1 files changed, 7 insertions, 1 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 2707b7c..d5e3563 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -9,6 +9,7 @@ Bregman projections for regularized OT
# Titouan Vayer <titouan.vayer@irisa.fr>
# Hicham Janati <hicham.janati@inria.fr>
# Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>
+# Alexander Tong <alexander.tong@yale.edu>
#
# License: MIT License
@@ -1346,12 +1347,17 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
err = 1
# build the convolution operator
+ # this is equivalent to blurring on horizontal then vertical directions
t = np.linspace(0, 1, A.shape[1])
[Y, X] = np.meshgrid(t, t)
xi1 = np.exp(-(X - Y)**2 / reg)
+ t = np.linspace(0, 1, A.shape[2])
+ [Y, X] = np.meshgrid(t, t)
+ xi2 = np.exp(-(X - Y)**2 / reg)
+
def K(x):
- return np.dot(np.dot(xi1, x), xi1)
+ return np.dot(np.dot(xi1, x), xi2)
while (err > stopThr and cpt < numItermax):