summaryrefslogtreecommitdiff
path: root/test/test_da.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-03-21 09:03:58 +0100
committerRémi Flamary <remi.flamary@gmail.com>2018-03-21 09:03:58 +0100
commit5efdf008865ea347775708b637d933e048d663ec (patch)
treea042b112c555a686454c4540dbe5659205e623cf /test/test_da.py
parent6fdf5de8fa27fa16d6b8910fe96eb67b7761aa0e (diff)
add test linear mapping class
Diffstat (limited to 'test/test_da.py')
-rw-r--r--test/test_da.py24
1 files changed, 24 insertions, 0 deletions
diff --git a/test/test_da.py b/test/test_da.py
index 7b63daf..a9d6d34 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -462,6 +462,30 @@ def test_linear_mapping():
np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
+def test_linear_mapping_class():
+
+ ns = 150
+ nt = 200
+
+ Xs, ys = get_data_classif('3gauss', ns)
+ Xt, yt = get_data_classif('3gauss2', nt)
+
+ otmap = ot.da.LinearTransport()
+
+ otmap.fit(Xs=Xs, Xt=Xt)
+ assert hasattr(otmap, "A_")
+ assert hasattr(otmap, "B_")
+ assert hasattr(otmap, "A1_")
+ assert hasattr(otmap, "B1_")
+
+ Xst = otmap.transform(Xs=Xs)
+
+ Ct = np.cov(Xt.T)
+ Cst = np.cov(Xst.T)
+
+ np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
+
+
def test_otda():
n_samples = 150 # nb samples