summaryrefslogtreecommitdiff
path: root/test/conftest.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/conftest.py')
-rw-r--r--test/conftest.py49
1 files changed, 49 insertions, 0 deletions
diff --git a/test/conftest.py b/test/conftest.py
new file mode 100644
index 0000000..876b525
--- /dev/null
+++ b/test/conftest.py
@@ -0,0 +1,49 @@
+# -*- coding: utf-8 -*-
+
+# Configuration file for pytest
+
+# License: MIT License
+
+import pytest
+from ot.backend import jax
+from ot.backend import get_backend_list
+import functools
+
+if jax:
+ from jax.config import config
+
+backend_list = get_backend_list()
+
+
+@pytest.fixture(params=backend_list)
+def nx(request):
+ backend = request.param
+ if backend.__name__ == "jax":
+ config.update("jax_enable_x64", True)
+
+ yield backend
+
+ if backend.__name__ == "jax":
+ config.update("jax_enable_x64", False)
+
+
+def skip_arg(arg, value, reason=None, getter=lambda x: x):
+ if reason is None:
+ reason = f"Param {arg} should be skipped for value {value}"
+
+ def wrapper(function):
+
+ @functools.wraps(function)
+ def wrapped(*args, **kwargs):
+ if arg in kwargs.keys() and getter(kwargs[arg]) == value:
+ pytest.skip(reason)
+ return function(*args, **kwargs)
+
+ return wrapped
+
+ return wrapper
+
+
+def pytest_configure(config):
+ pytest.skip_arg = skip_arg
+ pytest.skip_backend = functools.partial(skip_arg, "nx", getter=str)