From 7f7b1c547b54b394db975f4ff9d0287904a7b820 Mon Sep 17 00:00:00 2001 From: "Mokhtar Z. Alaya" Date: Sat, 18 Jan 2020 09:04:48 +0100 Subject: make autopep --- test/test_bregman.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) (limited to 'test/test_bregman.py') diff --git a/test/test_bregman.py b/test/test_bregman.py index e376715..fd0679b 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -106,7 +106,6 @@ def test_sinkhorn_variants_log(): @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) def test_barycenter(method): - n_bins = 100 # nb bins # Gaussian distributions @@ -133,7 +132,6 @@ def test_barycenter(method): def test_barycenter_stabilization(): - n_bins = 100 # nb bins # Gaussian distributions @@ -161,7 +159,6 @@ def test_barycenter_stabilization(): def test_wasserstein_bary_2d(): - size = 100 # size of a square image a1 = np.random.randn(size, size) a1 += a1.min() @@ -185,7 +182,6 @@ def test_wasserstein_bary_2d(): def test_unmix(): - n_bins = 50 # nb bins # Gaussian distributions @@ -207,7 +203,7 @@ def test_unmix(): # wasserstein reg = 1e-3 - um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01,) + um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01, ) np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03) np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03) @@ -256,7 +252,7 @@ def test_empirical_sinkhorn(): def test_empirical_sinkhorn_divergence(): - #Test sinkhorn divergence + # Test sinkhorn divergence n = 10 a = ot.unif(n) b = ot.unif(n) @@ -348,7 +344,10 @@ def test_screenkhorn(): x = rng.randn(n, 2) M = ot.dist(x, x) - G_s = ot.sinkhorn(a, b, M, 1e-03) - G_sc = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, verbose=True) - np.testing.assert_allclose(G_s.sum(0), G_sc.sum(0), atol=1e-02) - np.testing.assert_allclose(G_s.sum(1), G_sc.sum(1), atol=1e-02) \ No newline at end of file + # sinkhorn + G_sink = ot.sinkhorn(a, b, M, 1e-03) + # screenkhorn + G_screen = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, verbose=True) + # check marginals + np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02) + np.testing.assert_allclose(G_s.sum(1), G_screen.sum(1), atol=1e-02) -- cgit v1.2.3