summaryrefslogtreecommitdiff
path: root/test/test_da.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-07-26 11:38:17 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-07-26 11:38:17 +0200
commit67b011a2a6a0cb8dffbb7a2619875f0e0d79588c (patch)
tree4173fd1c541912195b2efff7f9e2ca42c90670f9 /test/test_da.py
parent68d74902bcd3d988fff8cb7713314063f04c0089 (diff)
numpy assert test_da
Diffstat (limited to 'test/test_da.py')
-rw-r--r--test/test_da.py21
1 files changed, 11 insertions, 10 deletions
diff --git a/test/test_da.py b/test/test_da.py
index 0d92b95..8df4795 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -6,9 +6,10 @@ import ot
# import pytest
-def test_OTDA():
+def test_otda():
- n = 150 # nb bins
+ n = 150 # nb samples
+ np.random.seed(0)
xs, ys = ot.datasets.get_data_classif('3gauss', n)
xt, yt = ot.datasets.get_data_classif('3gauss2', n)
@@ -21,8 +22,8 @@ def test_OTDA():
da_emd.interp() # interpolation of source samples
da_emd.predict(xs) # interpolation of source samples
- assert np.allclose(a, np.sum(da_emd.G, 1))
- assert np.allclose(b, np.sum(da_emd.G, 0))
+ np.testing.assert_allclose(a, np.sum(da_emd.G, 1))
+ np.testing.assert_allclose(b, np.sum(da_emd.G, 0))
# sinkhorn regularization
lambd = 1e-1
@@ -31,8 +32,8 @@ def test_OTDA():
da_entrop.interp()
da_entrop.predict(xs)
- assert np.allclose(a, np.sum(da_entrop.G, 1), rtol=1e-3, atol=1e-3)
- assert np.allclose(b, np.sum(da_entrop.G, 0), rtol=1e-3, atol=1e-3)
+ np.testing.assert_allclose(a, np.sum(da_entrop.G, 1), rtol=1e-3, atol=1e-3)
+ np.testing.assert_allclose(b, np.sum(da_entrop.G, 0), rtol=1e-3, atol=1e-3)
# non-convex Group lasso regularization
reg = 1e-1
@@ -42,8 +43,8 @@ def test_OTDA():
da_lpl1.interp()
da_lpl1.predict(xs)
- assert np.allclose(a, np.sum(da_lpl1.G, 1), rtol=1e-3, atol=1e-3)
- assert np.allclose(b, np.sum(da_lpl1.G, 0), rtol=1e-3, atol=1e-3)
+ np.testing.assert_allclose(a, np.sum(da_lpl1.G, 1), rtol=1e-3, atol=1e-3)
+ np.testing.assert_allclose(b, np.sum(da_lpl1.G, 0), rtol=1e-3, atol=1e-3)
# True Group lasso regularization
reg = 1e-1
@@ -53,8 +54,8 @@ def test_OTDA():
da_l1l2.interp()
da_l1l2.predict(xs)
- assert np.allclose(a, np.sum(da_l1l2.G, 1), rtol=1e-3, atol=1e-3)
- assert np.allclose(b, np.sum(da_l1l2.G, 0), rtol=1e-3, atol=1e-3)
+ np.testing.assert_allclose(a, np.sum(da_l1l2.G, 1), rtol=1e-3, atol=1e-3)
+ np.testing.assert_allclose(b, np.sum(da_l1l2.G, 0), rtol=1e-3, atol=1e-3)
# linear mapping
da_emd = ot.da.OTDA_mapping_linear() # init class