summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Tong <alexanderytong@gmail.com>2020-03-05 12:05:16 -0500
committerAlex Tong <alexanderytong@gmail.com>2020-03-06 08:49:36 -0500
commitd82e6eb1af99a982a4934d6bc019a9ab4ad5c880 (patch)
tree9b5b5bb699eac9a24f222e72c42a6e18549fd7c1
parent0baf83bbff6bd0c67244b3019509fe7518fb2d75 (diff)
Fix convolutional_barycenter kernel for non-symmetric images
Add authorship
-rw-r--r--ot/bregman.py8
-rw-r--r--test/test_bregman.py7
2 files changed, 14 insertions, 1 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 2707b7c..d5e3563 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -9,6 +9,7 @@ Bregman projections for regularized OT
# Titouan Vayer <titouan.vayer@irisa.fr>
# Hicham Janati <hicham.janati@inria.fr>
# Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>
+# Alexander Tong <alexander.tong@yale.edu>
#
# License: MIT License
@@ -1346,12 +1347,17 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
err = 1
# build the convolution operator
+ # this is equivalent to blurring on horizontal then vertical directions
t = np.linspace(0, 1, A.shape[1])
[Y, X] = np.meshgrid(t, t)
xi1 = np.exp(-(X - Y)**2 / reg)
+ t = np.linspace(0, 1, A.shape[2])
+ [Y, X] = np.meshgrid(t, t)
+ xi2 = np.exp(-(X - Y)**2 / reg)
+
def K(x):
- return np.dot(np.dot(xi1, x), xi1)
+ return np.dot(np.dot(xi1, x), xi2)
while (err > stopThr and cpt < numItermax):
diff --git a/test/test_bregman.py b/test/test_bregman.py
index f54ba9f..ec4388d 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -351,3 +351,10 @@ def test_screenkhorn():
# check marginals
np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
+
+
+def test_convolutional_barycenter_non_square():
+ # test for image with height not equal width
+ A = np.ones((2, 2, 3)) / (2 * 3)
+ b = ot.bregman.convolutional_barycenter2d(A, 1e-03)
+ np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)