From a5e0f0d40d5046a6639924347ef97e2ac80ad0c9 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Wed, 2 Feb 2022 11:53:12 +0100 Subject: [MRG] Add weak OT solver (#341) * add info in release file * update tests * pep8 * add weak OT example * update plot in doc * correction ewample with empirical sinkhorn * better thumbnail * comment from review * update documenation --- test/test_bregman.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) (limited to 'test/test_bregman.py') diff --git a/test/test_bregman.py b/test/test_bregman.py index 6e90aa4..1419f9b 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -60,7 +60,7 @@ def test_convergence_warning(method): ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1) -def test_not_impemented_method(): +def test_not_implemented_method(): # test sinkhorn w = 10 n = w ** 2 @@ -635,7 +635,7 @@ def test_wasserstein_bary_2d(nx, method): with pytest.raises(NotImplementedError): ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method) else: - bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method) + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method, verbose=True, log=True) bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)) np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) @@ -667,7 +667,7 @@ def test_wasserstein_bary_2d_debiased(nx, method): with pytest.raises(NotImplementedError): ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method) else: - bary_wass_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method) + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method, verbose=True, log=True) bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)) np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) @@ -940,14 +940,11 @@ def test_screenkhorn(nx): bb = nx.from_numpy(b) M_nx = nx.from_numpy(M, type_as=ab) - # np sinkhorn - G_sink_np = ot.sinkhorn(a, b, M, 1e-03) # sinkhorn - G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-03)) + G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1)) # screenkhorn - G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-03, uniform=True, verbose=True)) + G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-1, uniform=True, verbose=True)) # check marginals - np.testing.assert_allclose(G_sink_np, G_sink) np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02) np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02) -- cgit v1.2.3