diff options
Diffstat (limited to 'test/test_da.py')
-rw-r--r-- | test/test_da.py | 46 |
1 files changed, 44 insertions, 2 deletions
diff --git a/test/test_da.py b/test/test_da.py index 593dc53..3022721 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -326,8 +326,8 @@ def test_mapping_transport_class(): """test_mapping_transport """ - ns = 150 - nt = 200 + ns = 60 + nt = 120 Xs, ys = get_data_classif('3gauss', ns) Xt, yt = get_data_classif('3gauss2', nt) @@ -444,6 +444,48 @@ def test_mapping_transport_class(): assert len(otda.log_.keys()) != 0 +def test_linear_mapping(): + + ns = 150 + nt = 200 + + Xs, ys = get_data_classif('3gauss', ns) + Xt, yt = get_data_classif('3gauss2', nt) + + A, b = ot.da.OT_mapping_linear(Xs, Xt) + + Xst = Xs.dot(A) + b + + Ct = np.cov(Xt.T) + Cst = np.cov(Xst.T) + + np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) + + +def test_linear_mapping_class(): + + ns = 150 + nt = 200 + + Xs, ys = get_data_classif('3gauss', ns) + Xt, yt = get_data_classif('3gauss2', nt) + + otmap = ot.da.LinearTransport() + + otmap.fit(Xs=Xs, Xt=Xt) + assert hasattr(otmap, "A_") + assert hasattr(otmap, "B_") + assert hasattr(otmap, "A1_") + assert hasattr(otmap, "B1_") + + Xst = otmap.transform(Xs=Xs) + + Ct = np.cov(Xt.T) + Cst = np.cov(Xst.T) + + np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) + + def test_otda(): n_samples = 150 # nb samples |