summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/test_da.py62
1 files changed, 41 insertions, 21 deletions
diff --git a/test/test_da.py b/test/test_da.py
index 4bf0ab1..138936f 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -44,12 +44,32 @@ def test_class_jax_tf():
@pytest.skip_backend("jax")
@pytest.skip_backend("tf")
+@pytest.mark.parametrize("class_to_test", [ot.da.EMDTransport, ot.da.SinkhornTransport, ot.da.SinkhornLpl1Transport, ot.da.SinkhornL1l2Transport, ot.da.SinkhornL1l2Transport])
+def test_log_da(nx, class_to_test):
+
+ ns = 50
+ nt = 50
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
+ otda = class_to_test(log=True)
+
+ # test its computed
+ otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert hasattr(otda, "log_")
+
+
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
def test_sinkhorn_lpl1_transport_class(nx):
"""test_sinkhorn_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -136,7 +156,7 @@ def test_sinkhorn_l1l2_transport_class(nx):
"""
ns = 50
- nt = 100
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -230,8 +250,8 @@ def test_sinkhorn_transport_class(nx):
"""test_sinkhorn_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -323,8 +343,8 @@ def test_unbalanced_sinkhorn_transport_class(nx):
"""test_sinkhorn_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -402,8 +422,8 @@ def test_emd_transport_class(nx):
"""test_sinkhorn_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -558,8 +578,8 @@ def test_mapping_transport_class_specific_seed(nx):
@pytest.skip_backend("jax")
@pytest.skip_backend("tf")
def test_linear_mapping(nx):
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -579,8 +599,8 @@ def test_linear_mapping(nx):
@pytest.skip_backend("jax")
@pytest.skip_backend("tf")
def test_linear_mapping_class(nx):
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -609,9 +629,9 @@ def test_jcpot_transport_class(nx):
"""test_jcpot_transport
"""
- ns1 = 150
- ns2 = 150
- nt = 200
+ ns1 = 50
+ ns2 = 50
+ nt = 50
Xs1, ys1 = make_data_classif('3gauss', ns1)
Xs2, ys2 = make_data_classif('3gauss', ns2)
@@ -681,9 +701,9 @@ def test_jcpot_barycenter(nx):
"""test_jcpot_barycenter
"""
- ns1 = 150
- ns2 = 150
- nt = 200
+ ns1 = 50
+ ns2 = 50
+ nt = 50
sigma = 0.1
np.random.seed(1985)
@@ -713,8 +733,8 @@ def test_jcpot_barycenter(nx):
def test_emd_laplace_class(nx):
"""test_emd_laplace_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)