summaryrefslogtreecommitdiff
path: root/test/test_unbalanced.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_unbalanced.py')
-rw-r--r--test/test_unbalanced.py50
1 files changed, 50 insertions, 0 deletions
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index db59504..02b3fc3 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -1,6 +1,7 @@
"""Tests for module Unbalanced OT with entropy regularization"""
# Author: Hicham Janati <hicham.janati@inria.fr>
+# Laetitia Chapel <laetitia.chapel@univ-ubs.fr>
#
# License: MIT License
@@ -286,3 +287,52 @@ def test_implemented_methods(nx):
method=method)
barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
method=method)
+
+
+def test_mm_convergence(nx):
+ n = 100
+ rng = np.random.RandomState(42)
+ x = rng.randn(n, 2)
+ rng = np.random.RandomState(75)
+ y = rng.randn(n, 2)
+ a = ot.utils.unif(n)
+ b = ot.utils.unif(n)
+
+ M = ot.dist(x, y)
+ M = M / M.max()
+ reg_m = 100
+ a, b, M = nx.from_numpy(a, b, M)
+
+ G_kl, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl',
+ verbose=True, log=True)
+ loss_kl = nx.to_numpy(ot.unbalanced.mm_unbalanced2(
+ a, b, M, reg_m, div='kl', verbose=True))
+ G_l2, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2',
+ verbose=False, log=True)
+
+ # check if the marginals come close to the true ones when large reg
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b, atol=1e-03)
+
+ # check if mm_unbalanced2 returns the correct loss
+ np.testing.assert_allclose(nx.to_numpy(nx.sum(G_kl * M)), loss_kl,
+ atol=1e-5)
+
+ # check in case no histogram is provided
+ a_np, b_np = np.array([]), np.array([])
+ a, b = nx.from_numpy(a_np, b_np)
+
+ G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl')
+ G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2')
+ np.testing.assert_allclose(G_kl_null, G_kl)
+ np.testing.assert_allclose(G_l2_null, G_l2)
+
+ # test when G0 is given
+ G0 = ot.emd(a, b, M)
+ reg_m = 10000
+ G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0)
+ G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0)
+ np.testing.assert_allclose(G0, G_kl, atol=1e-05)
+ np.testing.assert_allclose(G0, G_l2, atol=1e-05)