summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/test_da.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/test/test_da.py b/test/test_da.py
index a8c258a..958df7b 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -8,6 +8,7 @@ import numpy as np
from numpy.testing import assert_allclose, assert_equal
import ot
+from ot.bregman import jcpot_barycenter
from ot.datasets import make_data_classif
from ot.utils import unif
@@ -603,7 +604,6 @@ def test_jcpot_transport_class():
# test transform
transp_Xs = otda.transform(Xs=Xs)
[assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)]
- #assert_equal(transp_Xs.shape, Xs.shape)
Xs_new, _ = make_data_classif('3gauss', ns1 + 1)
transp_Xs_new = otda.transform(Xs_new)