From 7af8c2147d61349f4d99ca33318a8a125e4569aa Mon Sep 17 00:00:00 2001 From: haoran010 <62598274+haoran010@users.noreply.github.com> Date: Mon, 25 Oct 2021 10:47:22 +0200 Subject: [MRG] Regularization path for l2 UOT (#274) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add reg path * debug examples and verify pep8 * pep8 and move the reg path examples in unbalanced folder Co-authored-by: haoran010 Co-authored-by: RĂ©mi Flamary --- test/test_regpath.py | 64 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 test/test_regpath.py (limited to 'test/test_regpath.py') diff --git a/test/test_regpath.py b/test/test_regpath.py new file mode 100644 index 0000000..967c27b --- /dev/null +++ b/test/test_regpath.py @@ -0,0 +1,64 @@ +"""Tests for module regularization path""" + +# Author: Haoran Wu +# +# License: MIT License + +import numpy as np +import ot + + +def test_fully_relaxed_path(): + + n_source = 50 # nb source samples (gaussian) + n_target = 40 # nb target samples (gaussian) + + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 2]]) + + np.random.seed(0) + xs = ot.datasets.make_2D_samples_gauss(n_source, mu, cov) + xt = ot.datasets.make_2D_samples_gauss(n_target, mu, cov) + + # source and target distributions + a = ot.utils.unif(n_source) + b = ot.utils.unif(n_target) + + # loss matrix + M = ot.dist(xs, xt) + M /= M.max() + + t, _, _ = ot.regpath.regularization_path(a, b, M, reg=1e-8, + semi_relaxed=False) + + G = t.reshape((n_source, n_target)) + np.testing.assert_allclose(a, G.sum(1), atol=1e-05) + np.testing.assert_allclose(b, G.sum(0), atol=1e-05) + + +def test_semi_relaxed_path(): + + n_source = 50 # nb source samples (gaussian) + n_target = 40 # nb target samples (gaussian) + + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 2]]) + + np.random.seed(0) + xs = ot.datasets.make_2D_samples_gauss(n_source, mu, cov) + xt = ot.datasets.make_2D_samples_gauss(n_target, mu, cov) + + # source and target distributions + a = ot.utils.unif(n_source) + b = ot.utils.unif(n_target) + + # loss matrix + M = ot.dist(xs, xt) + M /= M.max() + + t, _, _ = ot.regpath.regularization_path(a, b, M, reg=1e-8, + semi_relaxed=True) + + G = t.reshape((n_source, n_target)) + np.testing.assert_allclose(a, G.sum(1), atol=1e-05) + np.testing.assert_allclose(b, G.sum(0), atol=1e-10) -- cgit v1.2.3