summaryrefslogtreecommitdiff
path: root/test/test_unbalanced.py
diff options
context:
space:
mode:
authorLaetitia Chapel <laetitia.chapel@univ-ubs.fr>2022-04-11 15:38:18 +0200
committerGitHub <noreply@github.com>2022-04-11 15:38:18 +0200
commitac4cf442735ed4c0d5405ad861eddaa02afd4edd (patch)
tree6f0bf54ca7452621bc55548f2a2a2615b8975b54 /test/test_unbalanced.py
parent0b223ff883fd73601984a92c31cb70d4aded16e8 (diff)
[MRG] MM algorithms for UOT (#362)
* bugfix * update refs partial OT * fixes small typos in plot_partial_wass_and_gromov * fix small bugs in partial.py * update README * pep8 bugfix * modif doctest * fix bugtests * update on test_partial and test on the numerical precision on ot/partial * resolve merge pb * Delete partial.py * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * update releases.md with new MM UOT algorithms Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test/test_unbalanced.py')
-rw-r--r--test/test_unbalanced.py50
1 files changed, 50 insertions, 0 deletions
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index db59504..02b3fc3 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -1,6 +1,7 @@
"""Tests for module Unbalanced OT with entropy regularization"""
# Author: Hicham Janati <hicham.janati@inria.fr>
+# Laetitia Chapel <laetitia.chapel@univ-ubs.fr>
#
# License: MIT License
@@ -286,3 +287,52 @@ def test_implemented_methods(nx):
method=method)
barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
method=method)
+
+
+def test_mm_convergence(nx):
+ n = 100
+ rng = np.random.RandomState(42)
+ x = rng.randn(n, 2)
+ rng = np.random.RandomState(75)
+ y = rng.randn(n, 2)
+ a = ot.utils.unif(n)
+ b = ot.utils.unif(n)
+
+ M = ot.dist(x, y)
+ M = M / M.max()
+ reg_m = 100
+ a, b, M = nx.from_numpy(a, b, M)
+
+ G_kl, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl',
+ verbose=True, log=True)
+ loss_kl = nx.to_numpy(ot.unbalanced.mm_unbalanced2(
+ a, b, M, reg_m, div='kl', verbose=True))
+ G_l2, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2',
+ verbose=False, log=True)
+
+ # check if the marginals come close to the true ones when large reg
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b, atol=1e-03)
+
+ # check if mm_unbalanced2 returns the correct loss
+ np.testing.assert_allclose(nx.to_numpy(nx.sum(G_kl * M)), loss_kl,
+ atol=1e-5)
+
+ # check in case no histogram is provided
+ a_np, b_np = np.array([]), np.array([])
+ a, b = nx.from_numpy(a_np, b_np)
+
+ G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl')
+ G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2')
+ np.testing.assert_allclose(G_kl_null, G_kl)
+ np.testing.assert_allclose(G_l2_null, G_l2)
+
+ # test when G0 is given
+ G0 = ot.emd(a, b, M)
+ reg_m = 10000
+ G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0)
+ G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0)
+ np.testing.assert_allclose(G0, G_kl, atol=1e-05)
+ np.testing.assert_allclose(G0, G_l2, atol=1e-05)