From 1f307594244dd4c274b64d028823cbcfff302f37 Mon Sep 17 00:00:00 2001 From: Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> Date: Wed, 1 Jun 2022 08:52:47 +0200 Subject: [MRG] numItermax in 64 bits in EMD solver (#380) * Correct test_mm_convergence for cupy * Fix bug where number of iterations is limited to 2^31 * Update RELEASES.md * Replace size_t with long long * Use uint64_t instead of long long --- test/test_unbalanced.py | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) (limited to 'test/test_unbalanced.py') diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index 02b3fc3..fc40df0 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -295,26 +295,27 @@ def test_mm_convergence(nx): x = rng.randn(n, 2) rng = np.random.RandomState(75) y = rng.randn(n, 2) - a = ot.utils.unif(n) - b = ot.utils.unif(n) + a_np = ot.utils.unif(n) + b_np = ot.utils.unif(n) M = ot.dist(x, y) M = M / M.max() reg_m = 100 - a, b, M = nx.from_numpy(a, b, M) + a, b, M = nx.from_numpy(a_np, b_np, M) G_kl, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', - verbose=True, log=True) - loss_kl = nx.to_numpy(ot.unbalanced.mm_unbalanced2( - a, b, M, reg_m, div='kl', verbose=True)) + verbose=False, log=True) + loss_kl = nx.to_numpy( + ot.unbalanced.mm_unbalanced2(a, b, M, reg_m, div='kl', verbose=True) + ) G_l2, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', verbose=False, log=True) # check if the marginals come close to the true ones when large reg - np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a, atol=1e-03) - np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b, atol=1e-03) - np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a, atol=1e-03) - np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b, atol=1e-03) + np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a_np, atol=1e-03) + np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b_np, atol=1e-03) + np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a_np, atol=1e-03) + np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b_np, atol=1e-03) # check if mm_unbalanced2 returns the correct loss np.testing.assert_allclose(nx.to_numpy(nx.sum(G_kl * M)), loss_kl, @@ -324,15 +325,16 @@ def test_mm_convergence(nx): a_np, b_np = np.array([]), np.array([]) a, b = nx.from_numpy(a_np, b_np) - G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl') - G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2') - np.testing.assert_allclose(G_kl_null, G_kl) - np.testing.assert_allclose(G_l2_null, G_l2) + G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', verbose=False) + G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', verbose=False) + np.testing.assert_allclose(nx.to_numpy(G_kl_null), nx.to_numpy(G_kl)) + np.testing.assert_allclose(nx.to_numpy(G_l2_null), nx.to_numpy(G_l2)) # test when G0 is given G0 = ot.emd(a, b, M) + G0_np = nx.to_numpy(G0) reg_m = 10000 - G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0) - G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0) - np.testing.assert_allclose(G0, G_kl, atol=1e-05) - np.testing.assert_allclose(G0, G_l2, atol=1e-05) + G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0, verbose=False) + G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0, verbose=False) + np.testing.assert_allclose(G0_np, nx.to_numpy(G_kl), atol=1e-05) + np.testing.assert_allclose(G0_np, nx.to_numpy(G_l2), atol=1e-05) -- cgit v1.2.3