From 0eac835c70cc1a13bb998f3b6cdb0515fafc05e1 Mon Sep 17 00:00:00 2001 From: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> Date: Fri, 5 Nov 2021 15:57:08 +0100 Subject: [MRG] Tests with types/device on sliced/bregman/gromov functions (#303) * First draft : making pytest use gpu for torch testing * bug solve * Revert "bug solve" This reverts commit 29b013abd162f8693128f26d8129186b79923609. * Revert "First draft : making pytest use gpu for torch testing" This reverts commit 2778175bcc338016c704efa4187d132fe5162e3a. * sliced * sliced * ot 1dsolver * bregman * better print * jax works with sinkhorn, sinkhorn_log and sinkhornn_stabilized, no need to skip them * gromov & entropic gromov --- test/conftest.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) (limited to 'test/conftest.py') diff --git a/test/conftest.py b/test/conftest.py index 876b525..987d98e 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -11,6 +11,7 @@ import functools if jax: from jax.config import config + config.update("jax_enable_x64", True) backend_list = get_backend_list() @@ -18,16 +19,25 @@ backend_list = get_backend_list() @pytest.fixture(params=backend_list) def nx(request): backend = request.param - if backend.__name__ == "jax": - config.update("jax_enable_x64", True) yield backend - if backend.__name__ == "jax": - config.update("jax_enable_x64", False) - def skip_arg(arg, value, reason=None, getter=lambda x: x): + if isinstance(arg, tuple) or isinstance(arg, list): + n = len(arg) + else: + arg = (arg, ) + n = 1 + if n != 1 and (isinstance(value, tuple) or isinstance(value, list)): + pass + else: + value = (value, ) + if isinstance(getter, tuple) or isinstance(value, list): + pass + else: + getter = [getter] * n + if reason is None: reason = f"Param {arg} should be skipped for value {value}" @@ -35,7 +45,10 @@ def skip_arg(arg, value, reason=None, getter=lambda x: x): @functools.wraps(function) def wrapped(*args, **kwargs): - if arg in kwargs.keys() and getter(kwargs[arg]) == value: + if all( + arg[i] in kwargs.keys() and getter[i](kwargs[arg[i]]) == value[i] + for i in range(n) + ): pytest.skip(reason) return function(*args, **kwargs) -- cgit v1.2.3