summaryrefslogtreecommitdiff
path: root/test/test_ot.py
diff options
context:
space:
mode:
authorncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>2021-10-25 17:35:36 +0200
committerGitHub <noreply@github.com>2021-10-25 17:35:36 +0200
commit76450dddf8dd62b9714b72e99ae075516246d433 (patch)
tree67de8de1c185cc8e7fc33a1fc0613015824d1fbb /test/test_ot.py
parent7a65086dd340265d0223eb8ffb5c9a5152a82dff (diff)
[MRG] Backend for optim (#282)
* Backend for optim * Bug solve * Doc update * backend tests now with fixture * Unused imports removed * Docs * Docs * Docs * Outer product backend docs * Prettier docs * Pep8 * Mistakes corrected Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test/test_ot.py')
-rw-r--r--test/test_ot.py6
1 files changed, 1 insertions, 5 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index 3e953dc..4dfc510 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -12,9 +12,7 @@ from scipy.stats import wasserstein_distance
import ot
from ot.datasets import make_1D_gauss as gauss
-from ot.backend import get_backend_list, torch
-
-backend_list = get_backend_list()
+from ot.backend import torch
def test_emd_dimension_and_mass_mismatch():
@@ -37,7 +35,6 @@ def test_emd_dimension_and_mass_mismatch():
np.testing.assert_raises(AssertionError, ot.emd, a, b, M)
-@pytest.mark.parametrize('nx', backend_list)
def test_emd_backends(nx):
n_samples = 100
n_features = 2
@@ -59,7 +56,6 @@ def test_emd_backends(nx):
np.allclose(G, nx.to_numpy(Gb))
-@pytest.mark.parametrize('nx', backend_list)
def test_emd2_backends(nx):
n_samples = 100
n_features = 2