summaryrefslogtreecommitdiff
path: root/test/test_bregman.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2022-02-02 11:53:12 +0100
committerGitHub <noreply@github.com>2022-02-02 11:53:12 +0100
commita5e0f0d40d5046a6639924347ef97e2ac80ad0c9 (patch)
treedcd35e851ec2cc3f52eedbfa58fb6970664135c9 /test/test_bregman.py
parent71a57c68ea9eb2bc948c4dd1cce9928f34bf20e8 (diff)
[MRG] Add weak OT solver (#341)
* add info in release file * update tests * pep8 * add weak OT example * update plot in doc * correction ewample with empirical sinkhorn * better thumbnail * comment from review * update documenation
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r--test/test_bregman.py13
1 files changed, 5 insertions, 8 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 6e90aa4..1419f9b 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -60,7 +60,7 @@ def test_convergence_warning(method):
ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1)
-def test_not_impemented_method():
+def test_not_implemented_method():
# test sinkhorn
w = 10
n = w ** 2
@@ -635,7 +635,7 @@ def test_wasserstein_bary_2d(nx, method):
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
else:
- bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method)
+ bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method, verbose=True, log=True)
bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method))
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
@@ -667,7 +667,7 @@ def test_wasserstein_bary_2d_debiased(nx, method):
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
else:
- bary_wass_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method)
+ bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method, verbose=True, log=True)
bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method))
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
@@ -940,14 +940,11 @@ def test_screenkhorn(nx):
bb = nx.from_numpy(b)
M_nx = nx.from_numpy(M, type_as=ab)
- # np sinkhorn
- G_sink_np = ot.sinkhorn(a, b, M, 1e-03)
# sinkhorn
- G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-03))
+ G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1))
# screenkhorn
- G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-03, uniform=True, verbose=True))
+ G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-1, uniform=True, verbose=True))
# check marginals
- np.testing.assert_allclose(G_sink_np, G_sink)
np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)