From 67b011a2a6a0cb8dffbb7a2619875f0e0d79588c Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Wed, 26 Jul 2017 11:38:17 +0200 Subject: numpy assert test_da --- test/test_da.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/test/test_da.py b/test/test_da.py index 0d92b95..8df4795 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -6,9 +6,10 @@ import ot # import pytest -def test_OTDA(): +def test_otda(): - n = 150 # nb bins + n = 150 # nb samples + np.random.seed(0) xs, ys = ot.datasets.get_data_classif('3gauss', n) xt, yt = ot.datasets.get_data_classif('3gauss2', n) @@ -21,8 +22,8 @@ def test_OTDA(): da_emd.interp() # interpolation of source samples da_emd.predict(xs) # interpolation of source samples - assert np.allclose(a, np.sum(da_emd.G, 1)) - assert np.allclose(b, np.sum(da_emd.G, 0)) + np.testing.assert_allclose(a, np.sum(da_emd.G, 1)) + np.testing.assert_allclose(b, np.sum(da_emd.G, 0)) # sinkhorn regularization lambd = 1e-1 @@ -31,8 +32,8 @@ def test_OTDA(): da_entrop.interp() da_entrop.predict(xs) - assert np.allclose(a, np.sum(da_entrop.G, 1), rtol=1e-3, atol=1e-3) - assert np.allclose(b, np.sum(da_entrop.G, 0), rtol=1e-3, atol=1e-3) + np.testing.assert_allclose(a, np.sum(da_entrop.G, 1), rtol=1e-3, atol=1e-3) + np.testing.assert_allclose(b, np.sum(da_entrop.G, 0), rtol=1e-3, atol=1e-3) # non-convex Group lasso regularization reg = 1e-1 @@ -42,8 +43,8 @@ def test_OTDA(): da_lpl1.interp() da_lpl1.predict(xs) - assert np.allclose(a, np.sum(da_lpl1.G, 1), rtol=1e-3, atol=1e-3) - assert np.allclose(b, np.sum(da_lpl1.G, 0), rtol=1e-3, atol=1e-3) + np.testing.assert_allclose(a, np.sum(da_lpl1.G, 1), rtol=1e-3, atol=1e-3) + np.testing.assert_allclose(b, np.sum(da_lpl1.G, 0), rtol=1e-3, atol=1e-3) # True Group lasso regularization reg = 1e-1 @@ -53,8 +54,8 @@ def test_OTDA(): da_l1l2.interp() da_l1l2.predict(xs) - assert np.allclose(a, np.sum(da_l1l2.G, 1), rtol=1e-3, atol=1e-3) - assert np.allclose(b, np.sum(da_l1l2.G, 0), rtol=1e-3, atol=1e-3) + np.testing.assert_allclose(a, np.sum(da_l1l2.G, 1), rtol=1e-3, atol=1e-3) + np.testing.assert_allclose(b, np.sum(da_l1l2.G, 0), rtol=1e-3, atol=1e-3) # linear mapping da_emd = ot.da.OTDA_mapping_linear() # init class -- cgit v1.2.3