diff options
Diffstat (limited to 'test/test_da.py')
-rw-r--r-- | test/test_da.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/test/test_da.py b/test/test_da.py index a8c258a..958df7b 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -8,6 +8,7 @@ import numpy as np from numpy.testing import assert_allclose, assert_equal import ot +from ot.bregman import jcpot_barycenter from ot.datasets import make_data_classif from ot.utils import unif @@ -603,7 +604,6 @@ def test_jcpot_transport_class(): # test transform transp_Xs = otda.transform(Xs=Xs) [assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)] - #assert_equal(transp_Xs.shape, Xs.shape) Xs_new, _ = make_data_classif('3gauss', ns1 + 1) transp_Xs_new = otda.transform(Xs_new) |