summaryrefslogtreecommitdiff
path: root/test/test_bregman.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r--test/test_bregman.py13
1 files changed, 5 insertions, 8 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 6e90aa4..1419f9b 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -60,7 +60,7 @@ def test_convergence_warning(method):
ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1)
-def test_not_impemented_method():
+def test_not_implemented_method():
# test sinkhorn
w = 10
n = w ** 2
@@ -635,7 +635,7 @@ def test_wasserstein_bary_2d(nx, method):
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
else:
- bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method)
+ bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method, verbose=True, log=True)
bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method))
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
@@ -667,7 +667,7 @@ def test_wasserstein_bary_2d_debiased(nx, method):
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
else:
- bary_wass_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method)
+ bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method, verbose=True, log=True)
bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method))
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
@@ -940,14 +940,11 @@ def test_screenkhorn(nx):
bb = nx.from_numpy(b)
M_nx = nx.from_numpy(M, type_as=ab)
- # np sinkhorn
- G_sink_np = ot.sinkhorn(a, b, M, 1e-03)
# sinkhorn
- G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-03))
+ G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1))
# screenkhorn
- G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-03, uniform=True, verbose=True))
+ G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-1, uniform=True, verbose=True))
# check marginals
- np.testing.assert_allclose(G_sink_np, G_sink)
np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)