diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2018-03-21 08:29:50 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2018-03-21 08:29:50 +0100 |
commit | 6fdf5de8fa27fa16d6b8910fe96eb67b7761aa0e (patch) | |
tree | 55b8d6e463b131134b4aa9d4e3013bbb77811da6 /test/test_da.py | |
parent | 287c659ad35f5036ba2687caf73009ef455c7239 (diff) |
add linear mapping test + autopep8
Diffstat (limited to 'test/test_da.py')
-rw-r--r-- | test/test_da.py | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/test/test_da.py b/test/test_da.py index 593dc53..7b63daf 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -444,6 +444,24 @@ def test_mapping_transport_class(): assert len(otda.log_.keys()) != 0 +def test_linear_mapping(): + + ns = 150 + nt = 200 + + Xs, ys = get_data_classif('3gauss', ns) + Xt, yt = get_data_classif('3gauss2', nt) + + A, b = ot.da.OT_mapping_linear(Xs, Xt) + + Xst = Xs.dot(A) + b + + Ct = np.cov(Xt.T) + Cst = np.cov(Xst.T) + + np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) + + def test_otda(): n_samples = 150 # nb samples |