summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorMokhtar Z. Alaya <mzalaya@Mokhtars-iMac.local>2020-01-18 09:04:48 +0100
committerMokhtar Z. Alaya <mzalaya@Mokhtars-iMac.local>2020-01-18 09:04:48 +0100
commit7f7b1c547b54b394db975f4ff9d0287904a7b820 (patch)
treebc9194bde672b43623b6a5c62e659e53f0d4a1a2 /test
parentb3fb1ef40a482f0989686b79373060d764b62d38 (diff)
make autopep
Diffstat (limited to 'test')
-rw-r--r--test/test_bregman.py19
1 files changed, 9 insertions, 10 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index e376715..fd0679b 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -106,7 +106,6 @@ def test_sinkhorn_variants_log():
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
def test_barycenter(method):
-
n_bins = 100 # nb bins
# Gaussian distributions
@@ -133,7 +132,6 @@ def test_barycenter(method):
def test_barycenter_stabilization():
-
n_bins = 100 # nb bins
# Gaussian distributions
@@ -161,7 +159,6 @@ def test_barycenter_stabilization():
def test_wasserstein_bary_2d():
-
size = 100 # size of a square image
a1 = np.random.randn(size, size)
a1 += a1.min()
@@ -185,7 +182,6 @@ def test_wasserstein_bary_2d():
def test_unmix():
-
n_bins = 50 # nb bins
# Gaussian distributions
@@ -207,7 +203,7 @@ def test_unmix():
# wasserstein
reg = 1e-3
- um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01,)
+ um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01, )
np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03)
np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03)
@@ -256,7 +252,7 @@ def test_empirical_sinkhorn():
def test_empirical_sinkhorn_divergence():
- #Test sinkhorn divergence
+ # Test sinkhorn divergence
n = 10
a = ot.unif(n)
b = ot.unif(n)
@@ -348,7 +344,10 @@ def test_screenkhorn():
x = rng.randn(n, 2)
M = ot.dist(x, x)
- G_s = ot.sinkhorn(a, b, M, 1e-03)
- G_sc = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, verbose=True)
- np.testing.assert_allclose(G_s.sum(0), G_sc.sum(0), atol=1e-02)
- np.testing.assert_allclose(G_s.sum(1), G_sc.sum(1), atol=1e-02) \ No newline at end of file
+ # sinkhorn
+ G_sink = ot.sinkhorn(a, b, M, 1e-03)
+ # screenkhorn
+ G_screen = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, verbose=True)
+ # check marginals
+ np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
+ np.testing.assert_allclose(G_s.sum(1), G_screen.sum(1), atol=1e-02)