diff options
Diffstat (limited to 'test/test_da.py')
-rw-r--r-- | test/test_da.py | 79 |
1 files changed, 39 insertions, 40 deletions
diff --git a/test/test_da.py b/test/test_da.py index 4bf0ab1..c5f08d6 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -44,12 +44,32 @@ def test_class_jax_tf(): @pytest.skip_backend("jax") @pytest.skip_backend("tf") +@pytest.mark.parametrize("class_to_test", [ot.da.EMDTransport, ot.da.SinkhornTransport, ot.da.SinkhornLpl1Transport, ot.da.SinkhornL1l2Transport, ot.da.SinkhornL1l2Transport]) +def test_log_da(nx, class_to_test): + + ns = 50 + nt = 50 + + Xs, ys = make_data_classif('3gauss', ns) + Xt, yt = make_data_classif('3gauss2', nt) + + Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) + + otda = class_to_test(log=True) + + # test its computed + otda.fit(Xs=Xs, ys=ys, Xt=Xt) + assert hasattr(otda, "log_") + + +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") def test_sinkhorn_lpl1_transport_class(nx): """test_sinkhorn_transport """ - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -136,7 +156,7 @@ def test_sinkhorn_l1l2_transport_class(nx): """ ns = 50 - nt = 100 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -230,8 +250,8 @@ def test_sinkhorn_transport_class(nx): """test_sinkhorn_transport """ - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -323,8 +343,8 @@ def test_unbalanced_sinkhorn_transport_class(nx): """test_sinkhorn_transport """ - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -402,8 +422,8 @@ def test_emd_transport_class(nx): """test_sinkhorn_transport """ - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -557,30 +577,9 @@ def test_mapping_transport_class_specific_seed(nx): @pytest.skip_backend("jax") @pytest.skip_backend("tf") -def test_linear_mapping(nx): - ns = 150 - nt = 200 - - Xs, ys = make_data_classif('3gauss', ns) - Xt, yt = make_data_classif('3gauss2', nt) - - Xsb, Xtb = nx.from_numpy(Xs, Xt) - - A, b = ot.da.OT_mapping_linear(Xsb, Xtb) - - Xst = nx.to_numpy(nx.dot(Xsb, A) + b) - - Ct = np.cov(Xt.T) - Cst = np.cov(Xst.T) - - np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) - - -@pytest.skip_backend("jax") -@pytest.skip_backend("tf") def test_linear_mapping_class(nx): - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -609,9 +608,9 @@ def test_jcpot_transport_class(nx): """test_jcpot_transport """ - ns1 = 150 - ns2 = 150 - nt = 200 + ns1 = 50 + ns2 = 50 + nt = 50 Xs1, ys1 = make_data_classif('3gauss', ns1) Xs2, ys2 = make_data_classif('3gauss', ns2) @@ -681,9 +680,9 @@ def test_jcpot_barycenter(nx): """test_jcpot_barycenter """ - ns1 = 150 - ns2 = 150 - nt = 200 + ns1 = 50 + ns2 = 50 + nt = 50 sigma = 0.1 np.random.seed(1985) @@ -713,8 +712,8 @@ def test_jcpot_barycenter(nx): def test_emd_laplace_class(nx): """test_emd_laplace_transport """ - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) |