diff options
Diffstat (limited to 'test/test_da.py')
-rw-r--r-- | test/test_da.py | 21 |
1 files changed, 0 insertions, 21 deletions
diff --git a/test/test_da.py b/test/test_da.py index 138936f..c5f08d6 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -577,27 +577,6 @@ def test_mapping_transport_class_specific_seed(nx): @pytest.skip_backend("jax") @pytest.skip_backend("tf") -def test_linear_mapping(nx): - ns = 50 - nt = 50 - - Xs, ys = make_data_classif('3gauss', ns) - Xt, yt = make_data_classif('3gauss2', nt) - - Xsb, Xtb = nx.from_numpy(Xs, Xt) - - A, b = ot.da.OT_mapping_linear(Xsb, Xtb) - - Xst = nx.to_numpy(nx.dot(Xsb, A) + b) - - Ct = np.cov(Xt.T) - Cst = np.cov(Xst.T) - - np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) - - -@pytest.skip_backend("jax") -@pytest.skip_backend("tf") def test_linear_mapping_class(nx): ns = 50 nt = 50 |