summaryrefslogtreecommitdiff
path: root/test/conftest.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/conftest.py')
-rw-r--r--test/conftest.py12
1 files changed, 8 insertions, 4 deletions
diff --git a/test/conftest.py b/test/conftest.py
index 987d98e..c0db8ab 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -5,7 +5,7 @@
# License: MIT License
import pytest
-from ot.backend import jax
+from ot.backend import jax, tf
from ot.backend import get_backend_list
import functools
@@ -13,6 +13,10 @@ if jax:
from jax.config import config
config.update("jax_enable_x64", True)
+if tf:
+ from tensorflow.python.ops.numpy_ops import np_config
+ np_config.enable_numpy_behavior()
+
backend_list = get_backend_list()
@@ -24,16 +28,16 @@ def nx(request):
def skip_arg(arg, value, reason=None, getter=lambda x: x):
- if isinstance(arg, tuple) or isinstance(arg, list):
+ if isinstance(arg, (tuple, list)):
n = len(arg)
else:
arg = (arg, )
n = 1
- if n != 1 and (isinstance(value, tuple) or isinstance(value, list)):
+ if n != 1 and isinstance(value, (tuple, list)):
pass
else:
value = (value, )
- if isinstance(getter, tuple) or isinstance(value, list):
+ if isinstance(getter, (tuple, list)):
pass
else:
getter = [getter] * n