From 8490196dcc982c492b7565e1ec4de5f75f006acf Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Tue, 6 Dec 2022 18:02:44 +0100 Subject: [MRG] Fix bug in regularized OTDA l1lp with log (#413) * correct bug in DA l1lp with log * better tests and speedup with smaller dataset size * remove jax for log test * remove trndorflow for log test * pep8! --- test/test_da.py | 62 ++++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 41 insertions(+), 21 deletions(-) (limited to 'test/test_da.py') diff --git a/test/test_da.py b/test/test_da.py index 4bf0ab1..138936f 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -42,14 +42,34 @@ def test_class_jax_tf(): otda.fit(Xs=Xs, ys=ys, Xt=Xt) +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +@pytest.mark.parametrize("class_to_test", [ot.da.EMDTransport, ot.da.SinkhornTransport, ot.da.SinkhornLpl1Transport, ot.da.SinkhornL1l2Transport, ot.da.SinkhornL1l2Transport]) +def test_log_da(nx, class_to_test): + + ns = 50 + nt = 50 + + Xs, ys = make_data_classif('3gauss', ns) + Xt, yt = make_data_classif('3gauss2', nt) + + Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) + + otda = class_to_test(log=True) + + # test its computed + otda.fit(Xs=Xs, ys=ys, Xt=Xt) + assert hasattr(otda, "log_") + + @pytest.skip_backend("jax") @pytest.skip_backend("tf") def test_sinkhorn_lpl1_transport_class(nx): """test_sinkhorn_transport """ - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -136,7 +156,7 @@ def test_sinkhorn_l1l2_transport_class(nx): """ ns = 50 - nt = 100 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -230,8 +250,8 @@ def test_sinkhorn_transport_class(nx): """test_sinkhorn_transport """ - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -323,8 +343,8 @@ def test_unbalanced_sinkhorn_transport_class(nx): """test_sinkhorn_transport """ - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -402,8 +422,8 @@ def test_emd_transport_class(nx): """test_sinkhorn_transport """ - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -558,8 +578,8 @@ def test_mapping_transport_class_specific_seed(nx): @pytest.skip_backend("jax") @pytest.skip_backend("tf") def test_linear_mapping(nx): - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -579,8 +599,8 @@ def test_linear_mapping(nx): @pytest.skip_backend("jax") @pytest.skip_backend("tf") def test_linear_mapping_class(nx): - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -609,9 +629,9 @@ def test_jcpot_transport_class(nx): """test_jcpot_transport """ - ns1 = 150 - ns2 = 150 - nt = 200 + ns1 = 50 + ns2 = 50 + nt = 50 Xs1, ys1 = make_data_classif('3gauss', ns1) Xs2, ys2 = make_data_classif('3gauss', ns2) @@ -681,9 +701,9 @@ def test_jcpot_barycenter(nx): """test_jcpot_barycenter """ - ns1 = 150 - ns2 = 150 - nt = 200 + ns1 = 50 + ns2 = 50 + nt = 50 sigma = 0.1 np.random.seed(1985) @@ -713,8 +733,8 @@ def test_jcpot_barycenter(nx): def test_emd_laplace_class(nx): """test_emd_laplace_transport """ - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) -- cgit v1.2.3