diff options
author | ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> | 2021-10-25 17:35:36 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-10-25 17:35:36 +0200 |
commit | 76450dddf8dd62b9714b72e99ae075516246d433 (patch) | |
tree | 67de8de1c185cc8e7fc33a1fc0613015824d1fbb /test/test_ot.py | |
parent | 7a65086dd340265d0223eb8ffb5c9a5152a82dff (diff) |
[MRG] Backend for optim (#282)
* Backend for optim
* Bug solve
* Doc update
* backend tests now with fixture
* Unused imports removed
* Docs
* Docs
* Docs
* Outer product backend docs
* Prettier docs
* Pep8
* Mistakes corrected
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test/test_ot.py')
-rw-r--r-- | test/test_ot.py | 6 |
1 files changed, 1 insertions, 5 deletions
diff --git a/test/test_ot.py b/test/test_ot.py index 3e953dc..4dfc510 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -12,9 +12,7 @@ from scipy.stats import wasserstein_distance import ot from ot.datasets import make_1D_gauss as gauss -from ot.backend import get_backend_list, torch - -backend_list = get_backend_list() +from ot.backend import torch def test_emd_dimension_and_mass_mismatch(): @@ -37,7 +35,6 @@ def test_emd_dimension_and_mass_mismatch(): np.testing.assert_raises(AssertionError, ot.emd, a, b, M) -@pytest.mark.parametrize('nx', backend_list) def test_emd_backends(nx): n_samples = 100 n_features = 2 @@ -59,7 +56,6 @@ def test_emd_backends(nx): np.allclose(G, nx.to_numpy(Gb)) -@pytest.mark.parametrize('nx', backend_list) def test_emd2_backends(nx): n_samples = 100 n_features = 2 |