summaryrefslogtreecommitdiff
path: root/test/conftest.py
diff options
context:
space:
mode:
authorncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>2021-10-25 11:36:21 +0200
committerGitHub <noreply@github.com>2021-10-25 11:36:21 +0200
commit7a65086dd340265d0223eb8ffb5c9a5152a82dff (patch)
tree300f4a1cd645516fba1e440691fe48830d781b5c /test/conftest.py
parent7af8c2147d61349f4d99ca33318a8a125e4569aa (diff)
[MRG] Bregman backend (#280)
* Bregman * Resolve conflicts * Bug solve * Bregman updated for JAX compatibility * Tests coherence between backend improved * No longer enforcing 64 bits operations on Jax except for tests * Now using mixtures, to make backend dependent tests with less code * Better test skipping code * Pep8 + test optimizations * redundancy removed * Docs * Typo corrected * Typo * Typo * Docs * Docs * pep8 * Backend docs * Prettier docs * Mistake corrected * small changes * Better wording Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test/conftest.py')
-rw-r--r--test/conftest.py49
1 files changed, 49 insertions, 0 deletions
diff --git a/test/conftest.py b/test/conftest.py
new file mode 100644
index 0000000..876b525
--- /dev/null
+++ b/test/conftest.py
@@ -0,0 +1,49 @@
+# -*- coding: utf-8 -*-
+
+# Configuration file for pytest
+
+# License: MIT License
+
+import pytest
+from ot.backend import jax
+from ot.backend import get_backend_list
+import functools
+
+if jax:
+ from jax.config import config
+
+backend_list = get_backend_list()
+
+
+@pytest.fixture(params=backend_list)
+def nx(request):
+ backend = request.param
+ if backend.__name__ == "jax":
+ config.update("jax_enable_x64", True)
+
+ yield backend
+
+ if backend.__name__ == "jax":
+ config.update("jax_enable_x64", False)
+
+
+def skip_arg(arg, value, reason=None, getter=lambda x: x):
+ if reason is None:
+ reason = f"Param {arg} should be skipped for value {value}"
+
+ def wrapper(function):
+
+ @functools.wraps(function)
+ def wrapped(*args, **kwargs):
+ if arg in kwargs.keys() and getter(kwargs[arg]) == value:
+ pytest.skip(reason)
+ return function(*args, **kwargs)
+
+ return wrapped
+
+ return wrapper
+
+
+def pytest_configure(config):
+ pytest.skip_arg = skip_arg
+ pytest.skip_backend = functools.partial(skip_arg, "nx", getter=str)