summaryrefslogtreecommitdiff
path: root/test/test_da.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_da.py')
-rw-r--r--test/test_da.py79
1 files changed, 39 insertions, 40 deletions
diff --git a/test/test_da.py b/test/test_da.py
index 4bf0ab1..c5f08d6 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -44,12 +44,32 @@ def test_class_jax_tf():
@pytest.skip_backend("jax")
@pytest.skip_backend("tf")
+@pytest.mark.parametrize("class_to_test", [ot.da.EMDTransport, ot.da.SinkhornTransport, ot.da.SinkhornLpl1Transport, ot.da.SinkhornL1l2Transport, ot.da.SinkhornL1l2Transport])
+def test_log_da(nx, class_to_test):
+
+ ns = 50
+ nt = 50
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
+ otda = class_to_test(log=True)
+
+ # test its computed
+ otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert hasattr(otda, "log_")
+
+
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
def test_sinkhorn_lpl1_transport_class(nx):
"""test_sinkhorn_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -136,7 +156,7 @@ def test_sinkhorn_l1l2_transport_class(nx):
"""
ns = 50
- nt = 100
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -230,8 +250,8 @@ def test_sinkhorn_transport_class(nx):
"""test_sinkhorn_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -323,8 +343,8 @@ def test_unbalanced_sinkhorn_transport_class(nx):
"""test_sinkhorn_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -402,8 +422,8 @@ def test_emd_transport_class(nx):
"""test_sinkhorn_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -557,30 +577,9 @@ def test_mapping_transport_class_specific_seed(nx):
@pytest.skip_backend("jax")
@pytest.skip_backend("tf")
-def test_linear_mapping(nx):
- ns = 150
- nt = 200
-
- Xs, ys = make_data_classif('3gauss', ns)
- Xt, yt = make_data_classif('3gauss2', nt)
-
- Xsb, Xtb = nx.from_numpy(Xs, Xt)
-
- A, b = ot.da.OT_mapping_linear(Xsb, Xtb)
-
- Xst = nx.to_numpy(nx.dot(Xsb, A) + b)
-
- Ct = np.cov(Xt.T)
- Cst = np.cov(Xst.T)
-
- np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
-
-
-@pytest.skip_backend("jax")
-@pytest.skip_backend("tf")
def test_linear_mapping_class(nx):
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -609,9 +608,9 @@ def test_jcpot_transport_class(nx):
"""test_jcpot_transport
"""
- ns1 = 150
- ns2 = 150
- nt = 200
+ ns1 = 50
+ ns2 = 50
+ nt = 50
Xs1, ys1 = make_data_classif('3gauss', ns1)
Xs2, ys2 = make_data_classif('3gauss', ns2)
@@ -681,9 +680,9 @@ def test_jcpot_barycenter(nx):
"""test_jcpot_barycenter
"""
- ns1 = 150
- ns2 = 150
- nt = 200
+ ns1 = 50
+ ns2 = 50
+ nt = 50
sigma = 0.1
np.random.seed(1985)
@@ -713,8 +712,8 @@ def test_jcpot_barycenter(nx):
def test_emd_laplace_class(nx):
"""test_emd_laplace_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)