summaryrefslogtreecommitdiff
path: root/test/conftest.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/conftest.py')
-rw-r--r--test/conftest.py62
1 files changed, 62 insertions, 0 deletions
diff --git a/test/conftest.py b/test/conftest.py
new file mode 100644
index 0000000..987d98e
--- /dev/null
+++ b/test/conftest.py
@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+
+# Configuration file for pytest
+
+# License: MIT License
+
+import pytest
+from ot.backend import jax
+from ot.backend import get_backend_list
+import functools
+
+if jax:
+ from jax.config import config
+ config.update("jax_enable_x64", True)
+
+backend_list = get_backend_list()
+
+
+@pytest.fixture(params=backend_list)
+def nx(request):
+ backend = request.param
+
+ yield backend
+
+
+def skip_arg(arg, value, reason=None, getter=lambda x: x):
+ if isinstance(arg, tuple) or isinstance(arg, list):
+ n = len(arg)
+ else:
+ arg = (arg, )
+ n = 1
+ if n != 1 and (isinstance(value, tuple) or isinstance(value, list)):
+ pass
+ else:
+ value = (value, )
+ if isinstance(getter, tuple) or isinstance(value, list):
+ pass
+ else:
+ getter = [getter] * n
+
+ if reason is None:
+ reason = f"Param {arg} should be skipped for value {value}"
+
+ def wrapper(function):
+
+ @functools.wraps(function)
+ def wrapped(*args, **kwargs):
+ if all(
+ arg[i] in kwargs.keys() and getter[i](kwargs[arg[i]]) == value[i]
+ for i in range(n)
+ ):
+ pytest.skip(reason)
+ return function(*args, **kwargs)
+
+ return wrapped
+
+ return wrapper
+
+
+def pytest_configure(config):
+ pytest.skip_arg = skip_arg
+ pytest.skip_backend = functools.partial(skip_arg, "nx", getter=str)