diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/test_da.py | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/test/test_da.py b/test/test_da.py index aed9f61..93f7e83 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -5,13 +5,12 @@ # License: MIT License import numpy as np -import ot from numpy.testing.utils import assert_allclose, assert_equal + +import ot from ot.datasets import get_data_classif from ot.utils import unif -np.random.seed(42) - def test_sinkhorn_lpl1_transport_class(): """test_sinkhorn_transport @@ -325,3 +324,11 @@ def test_otda(): da_emd = ot.da.OTDA_mapping_kernel() # init class da_emd.fit(xs, xt, numItermax=10) # fit distributions da_emd.predict(xs) # interpolation of source samples + + +if __name__ == "__main__": + + test_sinkhorn_transport_class() + test_emd_transport_class() + test_sinkhorn_l1l2_transport_class() + test_sinkhorn_lpl1_transport_class() |