summaryrefslogtreecommitdiff
path: root/test/test_bregman.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r--test/test_bregman.py42
1 files changed, 33 insertions, 9 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index f70df10..6aa4e08 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -57,6 +57,9 @@ def test_sinkhorn_empty():
np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
+ # test empty weights greenkhorn
+ ot.sinkhorn([], [], M, 1, method='greenkhorn', stopThr=1e-10, log=True)
+
def test_sinkhorn_variants():
# test sinkhorn
@@ -106,7 +109,6 @@ def test_sinkhorn_variants_log():
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
def test_barycenter(method):
-
n_bins = 100 # nb bins
# Gaussian distributions
@@ -125,7 +127,7 @@ def test_barycenter(method):
# wasserstein
reg = 1e-2
- bary_wass = ot.bregman.barycenter(A, M, reg, weights, method=method)
+ bary_wass, log = ot.bregman.barycenter(A, M, reg, weights, method=method, log=True)
np.testing.assert_allclose(1, np.sum(bary_wass))
@@ -133,7 +135,6 @@ def test_barycenter(method):
def test_barycenter_stabilization():
-
n_bins = 100 # nb bins
# Gaussian distributions
@@ -154,14 +155,13 @@ def test_barycenter_stabilization():
reg = 1e-2
bar_stable = ot.bregman.barycenter(A, M, reg, weights,
method="sinkhorn_stabilized",
- stopThr=1e-8)
+ stopThr=1e-8, verbose=True)
bar = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn",
- stopThr=1e-8)
+ stopThr=1e-8, verbose=True)
np.testing.assert_allclose(bar, bar_stable)
def test_wasserstein_bary_2d():
-
size = 100 # size of a square image
a1 = np.random.randn(size, size)
a1 += a1.min()
@@ -185,7 +185,6 @@ def test_wasserstein_bary_2d():
def test_unmix():
-
n_bins = 50 # nb bins
# Gaussian distributions
@@ -207,7 +206,7 @@ def test_unmix():
# wasserstein
reg = 1e-3
- um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01,)
+ um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01, )
np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03)
np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03)
@@ -256,7 +255,7 @@ def test_empirical_sinkhorn():
def test_empirical_sinkhorn_divergence():
- #Test sinkhorn divergence
+ # Test sinkhorn divergence
n = 10
a = ot.unif(n)
b = ot.unif(n)
@@ -337,3 +336,28 @@ def test_implemented_methods():
ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
with pytest.raises(ValueError):
ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
+
+
+def test_screenkhorn():
+ # test screenkhorn
+ rng = np.random.RandomState(0)
+ n = 100
+ a = ot.unif(n)
+ b = ot.unif(n)
+
+ x = rng.randn(n, 2)
+ M = ot.dist(x, x)
+ # sinkhorn
+ G_sink = ot.sinkhorn(a, b, M, 1e-03)
+ # screenkhorn
+ G_screen = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, verbose=True)
+ # check marginals
+ np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
+ np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
+
+
+def test_convolutional_barycenter_non_square():
+ # test for image with height not equal width
+ A = np.ones((2, 2, 3)) / (2 * 3)
+ b = ot.bregman.convolutional_barycenter2d(A, 1e-03)
+ np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)