summaryrefslogtreecommitdiff
path: root/test/test_da.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_da.py')
-rw-r--r--test/test_da.py13
1 files changed, 10 insertions, 3 deletions
diff --git a/test/test_da.py b/test/test_da.py
index aed9f61..93f7e83 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -5,13 +5,12 @@
# License: MIT License
import numpy as np
-import ot
from numpy.testing.utils import assert_allclose, assert_equal
+
+import ot
from ot.datasets import get_data_classif
from ot.utils import unif
-np.random.seed(42)
-
def test_sinkhorn_lpl1_transport_class():
"""test_sinkhorn_transport
@@ -325,3 +324,11 @@ def test_otda():
da_emd = ot.da.OTDA_mapping_kernel() # init class
da_emd.fit(xs, xt, numItermax=10) # fit distributions
da_emd.predict(xs) # interpolation of source samples
+
+
+if __name__ == "__main__":
+
+ test_sinkhorn_transport_class()
+ test_emd_transport_class()
+ test_sinkhorn_l1l2_transport_class()
+ test_sinkhorn_lpl1_transport_class()