summaryrefslogtreecommitdiff
path: root/test/conftest.py
diff options
context:
space:
mode:
authorncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>2021-11-05 15:57:08 +0100
committerGitHub <noreply@github.com>2021-11-05 15:57:08 +0100
commit0eac835c70cc1a13bb998f3b6cdb0515fafc05e1 (patch)
treeb0c0fbce0109ba460a67a6356dc0ff03e2b3c1d5 /test/conftest.py
parent0e431c203a66c6d48e6bb1efeda149460472a0f0 (diff)
[MRG] Tests with types/device on sliced/bregman/gromov functions (#303)
* First draft : making pytest use gpu for torch testing * bug solve * Revert "bug solve" This reverts commit 29b013abd162f8693128f26d8129186b79923609. * Revert "First draft : making pytest use gpu for torch testing" This reverts commit 2778175bcc338016c704efa4187d132fe5162e3a. * sliced * sliced * ot 1dsolver * bregman * better print * jax works with sinkhorn, sinkhorn_log and sinkhornn_stabilized, no need to skip them * gromov & entropic gromov
Diffstat (limited to 'test/conftest.py')
-rw-r--r--test/conftest.py25
1 files changed, 19 insertions, 6 deletions
diff --git a/test/conftest.py b/test/conftest.py
index 876b525..987d98e 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -11,6 +11,7 @@ import functools
if jax:
from jax.config import config
+ config.update("jax_enable_x64", True)
backend_list = get_backend_list()
@@ -18,16 +19,25 @@ backend_list = get_backend_list()
@pytest.fixture(params=backend_list)
def nx(request):
backend = request.param
- if backend.__name__ == "jax":
- config.update("jax_enable_x64", True)
yield backend
- if backend.__name__ == "jax":
- config.update("jax_enable_x64", False)
-
def skip_arg(arg, value, reason=None, getter=lambda x: x):
+ if isinstance(arg, tuple) or isinstance(arg, list):
+ n = len(arg)
+ else:
+ arg = (arg, )
+ n = 1
+ if n != 1 and (isinstance(value, tuple) or isinstance(value, list)):
+ pass
+ else:
+ value = (value, )
+ if isinstance(getter, tuple) or isinstance(value, list):
+ pass
+ else:
+ getter = [getter] * n
+
if reason is None:
reason = f"Param {arg} should be skipped for value {value}"
@@ -35,7 +45,10 @@ def skip_arg(arg, value, reason=None, getter=lambda x: x):
@functools.wraps(function)
def wrapped(*args, **kwargs):
- if arg in kwargs.keys() and getter(kwargs[arg]) == value:
+ if all(
+ arg[i] in kwargs.keys() and getter[i](kwargs[arg[i]]) == value[i]
+ for i in range(n)
+ ):
pytest.skip(reason)
return function(*args, **kwargs)