diff options
Diffstat (limited to 'test/conftest.py')
-rw-r--r-- | test/conftest.py | 62 |
1 files changed, 62 insertions, 0 deletions
diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 0000000..987d98e --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- + +# Configuration file for pytest + +# License: MIT License + +import pytest +from ot.backend import jax +from ot.backend import get_backend_list +import functools + +if jax: + from jax.config import config + config.update("jax_enable_x64", True) + +backend_list = get_backend_list() + + +@pytest.fixture(params=backend_list) +def nx(request): + backend = request.param + + yield backend + + +def skip_arg(arg, value, reason=None, getter=lambda x: x): + if isinstance(arg, tuple) or isinstance(arg, list): + n = len(arg) + else: + arg = (arg, ) + n = 1 + if n != 1 and (isinstance(value, tuple) or isinstance(value, list)): + pass + else: + value = (value, ) + if isinstance(getter, tuple) or isinstance(value, list): + pass + else: + getter = [getter] * n + + if reason is None: + reason = f"Param {arg} should be skipped for value {value}" + + def wrapper(function): + + @functools.wraps(function) + def wrapped(*args, **kwargs): + if all( + arg[i] in kwargs.keys() and getter[i](kwargs[arg[i]]) == value[i] + for i in range(n) + ): + pytest.skip(reason) + return function(*args, **kwargs) + + return wrapped + + return wrapper + + +def pytest_configure(config): + pytest.skip_arg = skip_arg + pytest.skip_backend = functools.partial(skip_arg, "nx", getter=str) |