summaryrefslogtreecommitdiff
path: root/test/test_da.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_da.py')
-rw-r--r--test/test_da.py46
1 files changed, 44 insertions, 2 deletions
diff --git a/test/test_da.py b/test/test_da.py
index 593dc53..3022721 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -326,8 +326,8 @@ def test_mapping_transport_class():
"""test_mapping_transport
"""
- ns = 150
- nt = 200
+ ns = 60
+ nt = 120
Xs, ys = get_data_classif('3gauss', ns)
Xt, yt = get_data_classif('3gauss2', nt)
@@ -444,6 +444,48 @@ def test_mapping_transport_class():
assert len(otda.log_.keys()) != 0
+def test_linear_mapping():
+
+ ns = 150
+ nt = 200
+
+ Xs, ys = get_data_classif('3gauss', ns)
+ Xt, yt = get_data_classif('3gauss2', nt)
+
+ A, b = ot.da.OT_mapping_linear(Xs, Xt)
+
+ Xst = Xs.dot(A) + b
+
+ Ct = np.cov(Xt.T)
+ Cst = np.cov(Xst.T)
+
+ np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
+
+
+def test_linear_mapping_class():
+
+ ns = 150
+ nt = 200
+
+ Xs, ys = get_data_classif('3gauss', ns)
+ Xt, yt = get_data_classif('3gauss2', nt)
+
+ otmap = ot.da.LinearTransport()
+
+ otmap.fit(Xs=Xs, Xt=Xt)
+ assert hasattr(otmap, "A_")
+ assert hasattr(otmap, "B_")
+ assert hasattr(otmap, "A1_")
+ assert hasattr(otmap, "B1_")
+
+ Xst = otmap.transform(Xs=Xs)
+
+ Ct = np.cov(Xt.T)
+ Cst = np.cov(Xst.T)
+
+ np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
+
+
def test_otda():
n_samples = 150 # nb samples