summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>2021-11-05 15:57:08 +0100
committerGitHub <noreply@github.com>2021-11-05 15:57:08 +0100
commit0eac835c70cc1a13bb998f3b6cdb0515fafc05e1 (patch)
treeb0c0fbce0109ba460a67a6356dc0ff03e2b3c1d5 /test
parent0e431c203a66c6d48e6bb1efeda149460472a0f0 (diff)
[MRG] Tests with types/device on sliced/bregman/gromov functions (#303)
* First draft : making pytest use gpu for torch testing * bug solve * Revert "bug solve" This reverts commit 29b013abd162f8693128f26d8129186b79923609. * Revert "First draft : making pytest use gpu for torch testing" This reverts commit 2778175bcc338016c704efa4187d132fe5162e3a. * sliced * sliced * ot 1dsolver * bregman * better print * jax works with sinkhorn, sinkhorn_log and sinkhornn_stabilized, no need to skip them * gromov & entropic gromov
Diffstat (limited to 'test')
-rw-r--r--test/conftest.py25
-rw-r--r--test/test_1d_solver.py16
-rw-r--r--test/test_bregman.py45
-rw-r--r--test/test_gromov.py75
-rw-r--r--test/test_ot.py8
-rw-r--r--test/test_sliced.py44
6 files changed, 191 insertions, 22 deletions
diff --git a/test/conftest.py b/test/conftest.py
index 876b525..987d98e 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -11,6 +11,7 @@ import functools
if jax:
from jax.config import config
+ config.update("jax_enable_x64", True)
backend_list = get_backend_list()
@@ -18,16 +19,25 @@ backend_list = get_backend_list()
@pytest.fixture(params=backend_list)
def nx(request):
backend = request.param
- if backend.__name__ == "jax":
- config.update("jax_enable_x64", True)
yield backend
- if backend.__name__ == "jax":
- config.update("jax_enable_x64", False)
-
def skip_arg(arg, value, reason=None, getter=lambda x: x):
+ if isinstance(arg, tuple) or isinstance(arg, list):
+ n = len(arg)
+ else:
+ arg = (arg, )
+ n = 1
+ if n != 1 and (isinstance(value, tuple) or isinstance(value, list)):
+ pass
+ else:
+ value = (value, )
+ if isinstance(getter, tuple) or isinstance(value, list):
+ pass
+ else:
+ getter = [getter] * n
+
if reason is None:
reason = f"Param {arg} should be skipped for value {value}"
@@ -35,7 +45,10 @@ def skip_arg(arg, value, reason=None, getter=lambda x: x):
@functools.wraps(function)
def wrapped(*args, **kwargs):
- if arg in kwargs.keys() and getter(kwargs[arg]) == value:
+ if all(
+ arg[i] in kwargs.keys() and getter[i](kwargs[arg[i]]) == value[i]
+ for i in range(n)
+ ):
pytest.skip(reason)
return function(*args, **kwargs)
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
index 77b1234..cb85cb9 100644
--- a/test/test_1d_solver.py
+++ b/test/test_1d_solver.py
@@ -85,7 +85,6 @@ def test_wasserstein_1d(nx):
np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4)
-@pytest.mark.parametrize('nx', backend_list)
def test_wasserstein_1d_type_devices(nx):
rng = np.random.RandomState(0)
@@ -98,8 +97,7 @@ def test_wasserstein_1d_type_devices(nx):
rho_v /= rho_v.sum()
for tp in nx.__type_list__:
-
- print(tp.dtype)
+ print(nx.dtype_device(tp))
xb = nx.from_numpy(x, type_as=tp)
rho_ub = nx.from_numpy(rho_u, type_as=tp)
@@ -107,8 +105,7 @@ def test_wasserstein_1d_type_devices(nx):
res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
- if not str(nx) == 'numpy':
- assert res.dtype == xb.dtype
+ nx.assert_same_dtype_device(xb, res)
def test_emd_1d_emd2_1d():
@@ -162,17 +159,14 @@ def test_emd1d_type_devices(nx):
rho_v /= rho_v.sum()
for tp in nx.__type_list__:
-
- print(tp.dtype)
+ print(nx.dtype_device(tp))
xb = nx.from_numpy(x, type_as=tp)
rho_ub = nx.from_numpy(rho_u, type_as=tp)
rho_vb = nx.from_numpy(rho_v, type_as=tp)
emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
-
emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
- assert emd.dtype == xb.dtype
- if not str(nx) == 'numpy':
- assert emd2.dtype == xb.dtype
+ nx.assert_same_dtype_device(xb, emd)
+ nx.assert_same_dtype_device(xb, emd2)
diff --git a/test/test_bregman.py b/test/test_bregman.py
index edfe9c3..830052d 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -278,6 +278,51 @@ def test_sinkhorn_variants(nx):
np.testing.assert_allclose(G0, G_green, atol=1e-5)
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized",
+ "sinkhorn_epsilon_scaling",
+ "greenkhorn",
+ "sinkhorn_log"])
+@pytest.skip_arg(("nx", "method"), ("jax", "sinkhorn_epsilon_scaling"), reason="jax does not support sinkhorn_epsilon_scaling", getter=str)
+@pytest.skip_arg(("nx", "method"), ("jax", "greenkhorn"), reason="jax does not support greenkhorn", getter=str)
+def test_sinkhorn_variants_dtype_device(nx, method):
+ n = 100
+
+ x = np.random.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ ub = nx.from_numpy(u, type_as=tp)
+ Mb = nx.from_numpy(M, type_as=tp)
+
+ Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+
+ nx.assert_same_dtype_device(Mb, Gb)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"])
+def test_sinkhorn2_variants_dtype_device(nx, method):
+ n = 100
+
+ x = np.random.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ ub = nx.from_numpy(u, type_as=tp)
+ Mb = nx.from_numpy(M, type_as=tp)
+
+ lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+
+ nx.assert_same_dtype_device(Mb, lossb)
+
+
@pytest.skip_backend("jax")
def test_sinkhorn_variants_multi_b(nx):
# test sinkhorn
diff --git a/test/test_gromov.py b/test/test_gromov.py
index bcbcc3a..c4bc04c 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -75,6 +75,41 @@ def test_gromov(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+def test_gromov_dtype_device(nx):
+ # setup
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+
+ xt = xs[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ C1b = nx.from_numpy(C1, type_as=tp)
+ C2b = nx.from_numpy(C2, type_as=tp)
+ pb = nx.from_numpy(p, type_as=tp)
+ qb = nx.from_numpy(q, type_as=tp)
+
+ Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+
+ nx.assert_same_dtype_device(C1b, Gb)
+ nx.assert_same_dtype_device(C1b, gw_valb)
+
+
def test_gromov2_gradients():
n_samples = 50 # nb samples
@@ -168,6 +203,46 @@ def test_entropic_gromov(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+@pytest.skip_backend("jax", reason="test very slow with jax backend")
+def test_entropic_gromov_dtype_device(nx):
+ # setup
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+ xt = xs[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ C1b = nx.from_numpy(C1, type_as=tp)
+ C2b = nx.from_numpy(C2, type_as=tp)
+ pb = nx.from_numpy(p, type_as=tp)
+ qb = nx.from_numpy(q, type_as=tp)
+
+ Gb = ot.gromov.entropic_gromov_wasserstein(
+ C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
+ )
+ gw_valb = ot.gromov.entropic_gromov_wasserstein2(
+ C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
+ )
+
+ nx.assert_same_dtype_device(C1b, Gb)
+ nx.assert_same_dtype_device(C1b, gw_valb)
+
+
def test_pointwise_gromov(nx):
n_samples = 50 # nb samples
diff --git a/test/test_ot.py b/test/test_ot.py
index dc3930a..92f26a7 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -88,8 +88,7 @@ def test_emd_emd2_types_devices(nx):
M = ot.dist(x, y)
for tp in nx.__type_list__:
-
- print(tp.dtype)
+ print(nx.dtype_device(tp))
ab = nx.from_numpy(a, type_as=tp)
Mb = nx.from_numpy(M, type_as=tp)
@@ -98,9 +97,8 @@ def test_emd_emd2_types_devices(nx):
w = ot.emd2(ab, ab, Mb)
- assert Gb.dtype == Mb.dtype
- if not str(nx) == 'numpy':
- assert w.dtype == Mb.dtype
+ nx.assert_same_dtype_device(Mb, Gb)
+ nx.assert_same_dtype_device(Mb, w)
def test_emd2_gradients():
diff --git a/test/test_sliced.py b/test/test_sliced.py
index 0bd74ec..245202c 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -139,6 +139,28 @@ def test_sliced_backend(nx):
assert np.allclose(val0, valb)
+def test_sliced_backend_type_devices(nx):
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb = nx.from_numpy(x, type_as=tp)
+ yb = nx.from_numpy(y, type_as=tp)
+ Pb = nx.from_numpy(P, type_as=tp)
+
+ valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
+
+ nx.assert_same_dtype_device(xb, valb)
+
+
def test_max_sliced_backend(nx):
n = 100
@@ -167,3 +189,25 @@ def test_max_sliced_backend(nx):
valb = nx.to_numpy(ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb))
assert np.allclose(val0, valb)
+
+
+def test_max_sliced_backend_type_devices(nx):
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb = nx.from_numpy(x, type_as=tp)
+ yb = nx.from_numpy(y, type_as=tp)
+ Pb = nx.from_numpy(P, type_as=tp)
+
+ valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
+
+ nx.assert_same_dtype_device(xb, valb)