summaryrefslogtreecommitdiff
path: root/test/test_unbalanced.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_unbalanced.py')
-rw-r--r--test/test_unbalanced.py61
1 files changed, 43 insertions, 18 deletions
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index 02b3fc3..b76d738 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -5,6 +5,7 @@
#
# License: MIT License
+import itertools
import numpy as np
import ot
import pytest
@@ -289,32 +290,55 @@ def test_implemented_methods(nx):
method=method)
+@pytest.mark.parametrize("reg_div,regm_div", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2']))
+def test_lbfgsb_unbalanced(nx, reg_div, regm_div):
+
+ np.random.seed(42)
+
+ xs = np.random.randn(5, 2)
+ xt = np.random.randn(6, 2)
+
+ M = ot.dist(xs, xt)
+
+ a = ot.unif(5)
+ b = ot.unif(6)
+
+ G, log = ot.unbalanced.lbfgsb_unbalanced(a, b, M, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False)
+
+ ab, bb, Mb = nx.from_numpy(a, b, M)
+
+ Gb, log = ot.unbalanced.lbfgsb_unbalanced(ab, bb, Mb, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False)
+
+ np.testing.assert_allclose(G, nx.to_numpy(Gb))
+
+
def test_mm_convergence(nx):
n = 100
rng = np.random.RandomState(42)
x = rng.randn(n, 2)
rng = np.random.RandomState(75)
y = rng.randn(n, 2)
- a = ot.utils.unif(n)
- b = ot.utils.unif(n)
+ a_np = ot.utils.unif(n)
+ b_np = ot.utils.unif(n)
M = ot.dist(x, y)
M = M / M.max()
reg_m = 100
- a, b, M = nx.from_numpy(a, b, M)
+ a, b, M = nx.from_numpy(a_np, b_np, M)
G_kl, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl',
- verbose=True, log=True)
- loss_kl = nx.to_numpy(ot.unbalanced.mm_unbalanced2(
- a, b, M, reg_m, div='kl', verbose=True))
+ verbose=False, log=True)
+ loss_kl = nx.to_numpy(
+ ot.unbalanced.mm_unbalanced2(a, b, M, reg_m, div='kl', verbose=True)
+ )
G_l2, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2',
verbose=False, log=True)
# check if the marginals come close to the true ones when large reg
- np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a, atol=1e-03)
- np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b, atol=1e-03)
- np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a, atol=1e-03)
- np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a_np, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b_np, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a_np, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b_np, atol=1e-03)
# check if mm_unbalanced2 returns the correct loss
np.testing.assert_allclose(nx.to_numpy(nx.sum(G_kl * M)), loss_kl,
@@ -324,15 +348,16 @@ def test_mm_convergence(nx):
a_np, b_np = np.array([]), np.array([])
a, b = nx.from_numpy(a_np, b_np)
- G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl')
- G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2')
- np.testing.assert_allclose(G_kl_null, G_kl)
- np.testing.assert_allclose(G_l2_null, G_l2)
+ G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', verbose=False)
+ G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', verbose=False)
+ np.testing.assert_allclose(nx.to_numpy(G_kl_null), nx.to_numpy(G_kl))
+ np.testing.assert_allclose(nx.to_numpy(G_l2_null), nx.to_numpy(G_l2))
# test when G0 is given
G0 = ot.emd(a, b, M)
+ G0_np = nx.to_numpy(G0)
reg_m = 10000
- G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0)
- G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0)
- np.testing.assert_allclose(G0, G_kl, atol=1e-05)
- np.testing.assert_allclose(G0, G_l2, atol=1e-05)
+ G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0, verbose=False)
+ G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0, verbose=False)
+ np.testing.assert_allclose(G0_np, nx.to_numpy(G_kl), atol=1e-05)
+ np.testing.assert_allclose(G0_np, nx.to_numpy(G_l2), atol=1e-05)