diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2017-07-26 15:25:53 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-07-26 15:25:53 +0200 |
commit | 7638d019b43e52d17600cac653939e7cd807478c (patch) | |
tree | a77441ddf844d953a3e797a3fab2a1ee3b85bf34 /test/test_da.py | |
parent | 1cf304cee298e2752ce29c83e5201f593722c3af (diff) | |
parent | 838550ead9cc8a66d9b9c1212c5dda2457dc59a5 (diff) |
Merge pull request #19 from rflamary/pytest
Pytest with 89% coverage
Fixes #19
Diffstat (limited to 'test/test_da.py')
-rw-r--r-- | test/test_da.py | 70 |
1 files changed, 70 insertions, 0 deletions
diff --git a/test/test_da.py b/test/test_da.py new file mode 100644 index 0000000..dfba83f --- /dev/null +++ b/test/test_da.py @@ -0,0 +1,70 @@ +"""Tests for module da on Domain Adaptation """ + +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + +import numpy as np +import ot + + +def test_otda(): + + n_samples = 150 # nb samples + np.random.seed(0) + + xs, ys = ot.datasets.get_data_classif('3gauss', n_samples) + xt, yt = ot.datasets.get_data_classif('3gauss2', n_samples) + + a, b = ot.unif(n_samples), ot.unif(n_samples) + + # LP problem + da_emd = ot.da.OTDA() # init class + da_emd.fit(xs, xt) # fit distributions + da_emd.interp() # interpolation of source samples + da_emd.predict(xs) # interpolation of source samples + + np.testing.assert_allclose(a, np.sum(da_emd.G, 1)) + np.testing.assert_allclose(b, np.sum(da_emd.G, 0)) + + # sinkhorn regularization + lambd = 1e-1 + da_entrop = ot.da.OTDA_sinkhorn() + da_entrop.fit(xs, xt, reg=lambd) + da_entrop.interp() + da_entrop.predict(xs) + + np.testing.assert_allclose(a, np.sum(da_entrop.G, 1), rtol=1e-3, atol=1e-3) + np.testing.assert_allclose(b, np.sum(da_entrop.G, 0), rtol=1e-3, atol=1e-3) + + # non-convex Group lasso regularization + reg = 1e-1 + eta = 1e0 + da_lpl1 = ot.da.OTDA_lpl1() + da_lpl1.fit(xs, ys, xt, reg=reg, eta=eta) + da_lpl1.interp() + da_lpl1.predict(xs) + + np.testing.assert_allclose(a, np.sum(da_lpl1.G, 1), rtol=1e-3, atol=1e-3) + np.testing.assert_allclose(b, np.sum(da_lpl1.G, 0), rtol=1e-3, atol=1e-3) + + # True Group lasso regularization + reg = 1e-1 + eta = 2e0 + da_l1l2 = ot.da.OTDA_l1l2() + da_l1l2.fit(xs, ys, xt, reg=reg, eta=eta, numItermax=20, verbose=True) + da_l1l2.interp() + da_l1l2.predict(xs) + + np.testing.assert_allclose(a, np.sum(da_l1l2.G, 1), rtol=1e-3, atol=1e-3) + np.testing.assert_allclose(b, np.sum(da_l1l2.G, 0), rtol=1e-3, atol=1e-3) + + # linear mapping + da_emd = ot.da.OTDA_mapping_linear() # init class + da_emd.fit(xs, xt, numItermax=10) # fit distributions + da_emd.predict(xs) # interpolation of source samples + + # nonlinear mapping + da_emd = ot.da.OTDA_mapping_kernel() # init class + da_emd.fit(xs, xt, numItermax=10) # fit distributions + da_emd.predict(xs) # interpolation of source samples |