summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorhaoran010 <62598274+haoran010@users.noreply.github.com>2021-10-25 10:47:22 +0200
committerGitHub <noreply@github.com>2021-10-25 10:47:22 +0200
commit7af8c2147d61349f4d99ca33318a8a125e4569aa (patch)
tree5c08a89f2c998a6c1d734be28e4127130c1c1102 /test
parentd50d8145a5c0cf69d438b018cd5f1b914905e784 (diff)
[MRG] Regularization path for l2 UOT (#274)
* add reg path * debug examples and verify pep8 * pep8 and move the reg path examples in unbalanced folder Co-authored-by: haoran010 <haoran.wu@insa-rennes.fr> Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test')
-rw-r--r--test/test_regpath.py64
1 files changed, 64 insertions, 0 deletions
diff --git a/test/test_regpath.py b/test/test_regpath.py
new file mode 100644
index 0000000..967c27b
--- /dev/null
+++ b/test/test_regpath.py
@@ -0,0 +1,64 @@
+"""Tests for module regularization path"""
+
+# Author: Haoran Wu <haoran.wu@univ-ubs.fr>
+#
+# License: MIT License
+
+import numpy as np
+import ot
+
+
+def test_fully_relaxed_path():
+
+ n_source = 50 # nb source samples (gaussian)
+ n_target = 40 # nb target samples (gaussian)
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 2]])
+
+ np.random.seed(0)
+ xs = ot.datasets.make_2D_samples_gauss(n_source, mu, cov)
+ xt = ot.datasets.make_2D_samples_gauss(n_target, mu, cov)
+
+ # source and target distributions
+ a = ot.utils.unif(n_source)
+ b = ot.utils.unif(n_target)
+
+ # loss matrix
+ M = ot.dist(xs, xt)
+ M /= M.max()
+
+ t, _, _ = ot.regpath.regularization_path(a, b, M, reg=1e-8,
+ semi_relaxed=False)
+
+ G = t.reshape((n_source, n_target))
+ np.testing.assert_allclose(a, G.sum(1), atol=1e-05)
+ np.testing.assert_allclose(b, G.sum(0), atol=1e-05)
+
+
+def test_semi_relaxed_path():
+
+ n_source = 50 # nb source samples (gaussian)
+ n_target = 40 # nb target samples (gaussian)
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 2]])
+
+ np.random.seed(0)
+ xs = ot.datasets.make_2D_samples_gauss(n_source, mu, cov)
+ xt = ot.datasets.make_2D_samples_gauss(n_target, mu, cov)
+
+ # source and target distributions
+ a = ot.utils.unif(n_source)
+ b = ot.utils.unif(n_target)
+
+ # loss matrix
+ M = ot.dist(xs, xt)
+ M /= M.max()
+
+ t, _, _ = ot.regpath.regularization_path(a, b, M, reg=1e-8,
+ semi_relaxed=True)
+
+ G = t.reshape((n_source, n_target))
+ np.testing.assert_allclose(a, G.sum(1), atol=1e-05)
+ np.testing.assert_allclose(b, G.sum(0), atol=1e-10)