From 0411ea22a96f9c22af30156b45c16ef39ffb520d Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Thu, 15 Dec 2022 09:28:01 +0100 Subject: [MRG] New API for OT solver (with pre-computed ground cost matrix) (#388) * new API for OT solver * use itertools for product of parameters * add tests for result class * add tests for result class * add tests for result class last time? * add sinkhorn * make partial OT bckend compatible * add TV as unbalanced flavor * better tests * make smoth backend compatible and add l2 tregularizatio to solve * add reularizedd unbalanced * add test for more complex attibutes * add test for more complex attibutes * add generic unbalaned solver and implement it for ot.solve * add entropy to possible regularization * star of documentation for ot.solv * weird new pep8 * documenttaion for function ot.solve done * pep8 * Update ot/solvers.py Co-authored-by: Alexandre Gramfort * update release file * Apply suggestions from code review Co-authored-by: Alexandre Gramfort * add test NotImplemented * pep8 * pep8gcmp pep8! * compute kl in backend * debug tensorflow kl backend Co-authored-by: Alexandre Gramfort --- test/test_unbalanced.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) (limited to 'test/test_unbalanced.py') diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index fc40df0..b76d738 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -5,6 +5,7 @@ # # License: MIT License +import itertools import numpy as np import ot import pytest @@ -289,6 +290,28 @@ def test_implemented_methods(nx): method=method) +@pytest.mark.parametrize("reg_div,regm_div", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2'])) +def test_lbfgsb_unbalanced(nx, reg_div, regm_div): + + np.random.seed(42) + + xs = np.random.randn(5, 2) + xt = np.random.randn(6, 2) + + M = ot.dist(xs, xt) + + a = ot.unif(5) + b = ot.unif(6) + + G, log = ot.unbalanced.lbfgsb_unbalanced(a, b, M, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False) + + ab, bb, Mb = nx.from_numpy(a, b, M) + + Gb, log = ot.unbalanced.lbfgsb_unbalanced(ab, bb, Mb, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False) + + np.testing.assert_allclose(G, nx.to_numpy(Gb)) + + def test_mm_convergence(nx): n = 100 rng = np.random.RandomState(42) -- cgit v1.2.3