diff options
Diffstat (limited to 'test/test_unbalanced.py')
-rw-r--r-- | test/test_unbalanced.py | 61 |
1 file changed, 43 insertions, 18 deletions
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index 02b3fc3..b76d738 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -5,6 +5,7 @@ # # License: MIT License +import itertools import numpy as np import ot import pytest @@ -289,32 +290,55 @@ def test_implemented_methods(nx): method=method) +@pytest.mark.parametrize("reg_div,regm_div", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2'])) +def test_lbfgsb_unbalanced(nx, reg_div, regm_div): + + np.random.seed(42) + + xs = np.random.randn(5, 2) + xt = np.random.randn(6, 2) + + M = ot.dist(xs, xt) + + a = ot.unif(5) + b = ot.unif(6) + + G, log = ot.unbalanced.lbfgsb_unbalanced(a, b, M, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False) + + ab, bb, Mb = nx.from_numpy(a, b, M) + + Gb, log = ot.unbalanced.lbfgsb_unbalanced(ab, bb, Mb, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False) + + np.testing.assert_allclose(G, nx.to_numpy(Gb)) + + def test_mm_convergence(nx): n = 100 rng = np.random.RandomState(42) x = rng.randn(n, 2) rng = np.random.RandomState(75) y = rng.randn(n, 2) - a = ot.utils.unif(n) - b = ot.utils.unif(n) + a_np = ot.utils.unif(n) + b_np = ot.utils.unif(n) M = ot.dist(x, y) M = M / M.max() reg_m = 100 - a, b, M = nx.from_numpy(a, b, M) + a, b, M = nx.from_numpy(a_np, b_np, M) G_kl, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', - verbose=True, log=True) - loss_kl = nx.to_numpy(ot.unbalanced.mm_unbalanced2( - a, b, M, reg_m, div='kl', verbose=True)) + verbose=False, log=True) + loss_kl = nx.to_numpy( + ot.unbalanced.mm_unbalanced2(a, b, M, reg_m, div='kl', verbose=True) + ) G_l2, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', verbose=False, log=True) # check if the marginals come close to the true ones when large reg - np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a, atol=1e-03) - np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b, atol=1e-03) - 
np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a, atol=1e-03) - np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b, atol=1e-03) + np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a_np, atol=1e-03) + np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b_np, atol=1e-03) + np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a_np, atol=1e-03) + np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b_np, atol=1e-03) # check if mm_unbalanced2 returns the correct loss np.testing.assert_allclose(nx.to_numpy(nx.sum(G_kl * M)), loss_kl, @@ -324,15 +348,16 @@ def test_mm_convergence(nx): a_np, b_np = np.array([]), np.array([]) a, b = nx.from_numpy(a_np, b_np) - G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl') - G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2') - np.testing.assert_allclose(G_kl_null, G_kl) - np.testing.assert_allclose(G_l2_null, G_l2) + G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', verbose=False) + G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', verbose=False) + np.testing.assert_allclose(nx.to_numpy(G_kl_null), nx.to_numpy(G_kl)) + np.testing.assert_allclose(nx.to_numpy(G_l2_null), nx.to_numpy(G_l2)) # test when G0 is given G0 = ot.emd(a, b, M) + G0_np = nx.to_numpy(G0) reg_m = 10000 - G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0) - G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0) - np.testing.assert_allclose(G0, G_kl, atol=1e-05) - np.testing.assert_allclose(G0, G_l2, atol=1e-05) + G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0, verbose=False) + G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0, verbose=False) + np.testing.assert_allclose(G0_np, nx.to_numpy(G_kl), atol=1e-05) + np.testing.assert_allclose(G0_np, nx.to_numpy(G_l2), atol=1e-05) |