summaryrefslogtreecommitdiff
path: root/test/test_unbalanced.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_unbalanced.py')
-rw-r--r--test/test_unbalanced.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index fc40df0..b76d738 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -5,6 +5,7 @@
#
# License: MIT License
+import itertools
import numpy as np
import ot
import pytest
@@ -289,6 +290,28 @@ def test_implemented_methods(nx):
method=method)
+@pytest.mark.parametrize("reg_div,regm_div", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2']))
+def test_lbfgsb_unbalanced(nx, reg_div, regm_div):
+
+ np.random.seed(42)
+
+ xs = np.random.randn(5, 2)
+ xt = np.random.randn(6, 2)
+
+ M = ot.dist(xs, xt)
+
+ a = ot.unif(5)
+ b = ot.unif(6)
+
+ G, log = ot.unbalanced.lbfgsb_unbalanced(a, b, M, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False)
+
+ ab, bb, Mb = nx.from_numpy(a, b, M)
+
+ Gb, log = ot.unbalanced.lbfgsb_unbalanced(ab, bb, Mb, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False)
+
+ np.testing.assert_allclose(G, nx.to_numpy(Gb))
+
+
def test_mm_convergence(nx):
n = 100
rng = np.random.RandomState(42)