diff options
Diffstat (limited to 'test/test_unbalanced.py')
-rw-r--r-- | test/test_unbalanced.py | 50 |
1 files changed, 50 insertions, 0 deletions
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index db59504..02b3fc3 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -1,6 +1,7 @@ """Tests for module Unbalanced OT with entropy regularization""" # Author: Hicham Janati <hicham.janati@inria.fr> +# Laetitia Chapel <laetitia.chapel@univ-ubs.fr> # # License: MIT License @@ -286,3 +287,52 @@ def test_implemented_methods(nx): method=method) barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, method=method) + + +def test_mm_convergence(nx): + n = 100 + rng = np.random.RandomState(42) + x = rng.randn(n, 2) + rng = np.random.RandomState(75) + y = rng.randn(n, 2) + a = ot.utils.unif(n) + b = ot.utils.unif(n) + + M = ot.dist(x, y) + M = M / M.max() + reg_m = 100 + a, b, M = nx.from_numpy(a, b, M) + + G_kl, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', + verbose=True, log=True) + loss_kl = nx.to_numpy(ot.unbalanced.mm_unbalanced2( + a, b, M, reg_m, div='kl', verbose=True)) + G_l2, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', + verbose=False, log=True) + + # check if the marginals come close to the true ones when large reg + np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a, atol=1e-03) + np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b, atol=1e-03) + np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a, atol=1e-03) + np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b, atol=1e-03) + + # check if mm_unbalanced2 returns the correct loss + np.testing.assert_allclose(nx.to_numpy(nx.sum(G_kl * M)), loss_kl, + atol=1e-5) + + # check in case no histogram is provided + a_np, b_np = np.array([]), np.array([]) + a, b = nx.from_numpy(a_np, b_np) + + G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl') + G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2') + np.testing.assert_allclose(G_kl_null, G_kl) + np.testing.assert_allclose(G_l2_null, G_l2) + + # test when G0 is given + G0 = ot.emd(a, b, M) + reg_m = 10000 + G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0) + G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0) + np.testing.assert_allclose(G0, G_kl, atol=1e-05) + np.testing.assert_allclose(G0, G_l2, atol=1e-05) |