-rw-r--r-- | ot/backend.py          | 59
-rw-r--r-- | ot/sliced.py           |  8
-rw-r--r-- | test/conftest.py       | 25
-rw-r--r-- | test/test_1d_solver.py | 16
-rw-r--r-- | test/test_bregman.py   | 45
-rw-r--r-- | test/test_gromov.py    | 75
-rw-r--r-- | test/test_ot.py        |  8
-rw-r--r-- | test/test_sliced.py    | 44
8 files changed, 247 insertions, 33 deletions
diff --git a/ot/backend.py b/ot/backend.py
index 55e10d3..a044f84 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -653,6 +653,18 @@ class Backend():
         """
         raise NotImplementedError()
 
+    def dtype_device(self, a):
+        r"""
+        Returns the dtype and the device of the given tensor.
+        """
+        raise NotImplementedError()
+
+    def assert_same_dtype_device(self, a, b):
+        r"""
+        Checks whether or not the two given inputs have the same dtype as well as the same device.
+        """
+        raise NotImplementedError()
+
 
 class NumpyBackend(Backend):
     """
@@ -880,6 +892,16 @@ class NumpyBackend(Backend):
     def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
         return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
 
+    def dtype_device(self, a):
+        if hasattr(a, "dtype"):
+            return a.dtype, "cpu"
+        else:
+            return type(a), "cpu"
+
+    def assert_same_dtype_device(self, a, b):
+        # numpy has implicit type conversion, so this check always passes
+        pass
+
 
 class JaxBackend(Backend):
     """
@@ -899,17 +921,20 @@ class JaxBackend(Backend):
         self.rng_ = jax.random.PRNGKey(42)
 
         for d in jax.devices():
-            self.__type_list__ = [jax.device_put(jnp.array(1, dtype=np.float32), d),
-                                  jax.device_put(jnp.array(1, dtype=np.float64), d)]
+            self.__type_list__ = [jax.device_put(jnp.array(1, dtype=jnp.float32), d),
+                                  jax.device_put(jnp.array(1, dtype=jnp.float64), d)]
 
     def to_numpy(self, a):
         return np.array(a)
 
+    def _change_device(self, a, type_as):
+        return jax.device_put(a, type_as.device_buffer.device())
+
     def from_numpy(self, a, type_as=None):
         if type_as is None:
             return jnp.array(a)
         else:
-            return jax.device_put(jnp.array(a).astype(type_as.dtype), type_as.device_buffer.device())
+            return self._change_device(jnp.array(a).astype(type_as.dtype), type_as)
 
     def set_gradients(self, val, inputs, grads):
         from jax.flatten_util import ravel_pytree
@@ -928,13 +953,13 @@ class JaxBackend(Backend):
         if type_as is None:
             return jnp.zeros(shape)
         else:
-            return jnp.zeros(shape, dtype=type_as.dtype)
+            return self._change_device(jnp.zeros(shape, dtype=type_as.dtype), type_as)
 
     def ones(self, shape, type_as=None):
         if type_as is None:
             return jnp.ones(shape)
         else:
-            return jnp.ones(shape, dtype=type_as.dtype)
+            return self._change_device(jnp.ones(shape, dtype=type_as.dtype), type_as)
 
     def arange(self, stop, start=0, step=1, type_as=None):
         return jnp.arange(start, stop, step)
@@ -943,13 +968,13 @@ class JaxBackend(Backend):
         if type_as is None:
             return jnp.full(shape, fill_value)
         else:
-            return jnp.full(shape, fill_value, dtype=type_as.dtype)
+            return self._change_device(jnp.full(shape, fill_value, dtype=type_as.dtype), type_as)
 
     def eye(self, N, M=None, type_as=None):
         if type_as is None:
             return jnp.eye(N, M)
         else:
-            return jnp.eye(N, M, dtype=type_as.dtype)
+            return self._change_device(jnp.eye(N, M, dtype=type_as.dtype), type_as)
 
     def sum(self, a, axis=None, keepdims=False):
         return jnp.sum(a, axis, keepdims=keepdims)
@@ -1127,6 +1152,16 @@ class JaxBackend(Backend):
     def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
         return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
 
+    def dtype_device(self, a):
+        return a.dtype, a.device_buffer.device()
+
+    def assert_same_dtype_device(self, a, b):
+        a_dtype, a_device = self.dtype_device(a)
+        b_dtype, b_device = self.dtype_device(b)
+
+        assert a_dtype == b_dtype, "Dtype discrepancy"
+        assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
+
 
 class TorchBackend(Backend):
     """
@@ -1455,3 +1490,13 @@ class TorchBackend(Backend):
 
     def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
         return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
+
+    def dtype_device(self, a):
+        return a.dtype, a.device
+
+    def assert_same_dtype_device(self, a, b):
+        a_dtype, a_device = self.dtype_device(a)
+        b_dtype, b_device = self.dtype_device(b)
+
+        assert a_dtype == b_dtype, "Dtype discrepancy"
+        assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
diff --git a/ot/sliced.py b/ot/sliced.py
index 7c09111..cf2d3be 100644
--- a/ot/sliced.py
+++ b/ot/sliced.py
@@ -139,9 +139,9 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2,
                          X_t.shape[1]))
 
     if a is None:
-        a = nx.full(n, 1 / n)
+        a = nx.full(n, 1 / n, type_as=X_s)
     if b is None:
-        b = nx.full(m, 1 / m)
+        b = nx.full(m, 1 / m, type_as=X_s)
 
     d = X_s.shape[1]
 
@@ -238,9 +238,9 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50,
                          X_t.shape[1]))
 
     if a is None:
-        a = nx.full(n, 1 / n)
+        a = nx.full(n, 1 / n, type_as=X_s)
     if b is None:
-        b = nx.full(m, 1 / m)
+        b = nx.full(m, 1 / m, type_as=X_s)
 
     d = X_s.shape[1]
 
diff --git a/test/conftest.py b/test/conftest.py
index 876b525..987d98e 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -11,6 +11,7 @@ import functools
 
 if jax:
     from jax.config import config
+    config.update("jax_enable_x64", True)
 
 backend_list = get_backend_list()
 
@@ -18,16 +19,25 @@ backend_list = get_backend_list()
 @pytest.fixture(params=backend_list)
 def nx(request):
     backend = request.param
-    if backend.__name__ == "jax":
-        config.update("jax_enable_x64", True)
 
     yield backend
 
-    if backend.__name__ == "jax":
-        config.update("jax_enable_x64", False)
-
 
 def skip_arg(arg, value, reason=None, getter=lambda x: x):
+    if isinstance(arg, tuple) or isinstance(arg, list):
+        n = len(arg)
+    else:
+        arg = (arg, )
+        n = 1
+    if n != 1 and (isinstance(value, tuple) or isinstance(value, list)):
+        pass
+    else:
+        value = (value, ) * n
+    if isinstance(getter, tuple) or isinstance(getter, list):
+        pass
+    else:
+        getter = [getter] * n
+
     if reason is None:
         reason = f"Param {arg} should be skipped for value {value}"
 
@@ -35,7 +45,10 @@ def skip_arg(arg, value, reason=None, getter=lambda x: x):
 
         @functools.wraps(function)
         def wrapped(*args, **kwargs):
-            if arg in kwargs.keys() and getter(kwargs[arg]) == value:
+            if all(
+                arg[i] in kwargs.keys() and getter[i](kwargs[arg[i]]) == value[i]
+                for i in range(n)
+            ):
                 pytest.skip(reason)
             return function(*args, **kwargs)
 
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
index 77b1234..cb85cb9 100644
--- a/test/test_1d_solver.py
+++ b/test/test_1d_solver.py
@@ -85,7 +85,6 @@ def test_wasserstein_1d(nx):
     np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4)
 
 
-@pytest.mark.parametrize('nx', backend_list)
 def test_wasserstein_1d_type_devices(nx):
     rng = np.random.RandomState(0)
 
@@ -98,8 +97,7 @@
     rho_v /= rho_v.sum()
 
     for tp in nx.__type_list__:
-
-        print(tp.dtype)
+        print(nx.dtype_device(tp))
 
         xb = nx.from_numpy(x, type_as=tp)
         rho_ub = nx.from_numpy(rho_u, type_as=tp)
@@ -107,8 +105,7 @@
 
         res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
 
-        if not str(nx) == 'numpy':
-            assert res.dtype == xb.dtype
+        nx.assert_same_dtype_device(xb, res)
 
 
 def test_emd_1d_emd2_1d():
@@ -162,17 +159,14 @@ def test_emd1d_type_devices(nx):
     rho_v /= rho_v.sum()
 
     for tp in nx.__type_list__:
-
-        print(tp.dtype)
+        print(nx.dtype_device(tp))
 
         xb = nx.from_numpy(x, type_as=tp)
         rho_ub = nx.from_numpy(rho_u, type_as=tp)
         rho_vb = nx.from_numpy(rho_v, type_as=tp)
 
         emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
-
         emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
 
-        assert emd.dtype == xb.dtype
-        if not str(nx) == 'numpy':
-            assert emd2.dtype == xb.dtype
+        nx.assert_same_dtype_device(xb, emd)
+        nx.assert_same_dtype_device(xb, emd2)
diff --git a/test/test_bregman.py b/test/test_bregman.py
index edfe9c3..830052d 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -278,6 +278,51 @@ def test_sinkhorn_variants(nx):
     np.testing.assert_allclose(G0, G_green, atol=1e-5)
 
 
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized",
+                                    "sinkhorn_epsilon_scaling",
+                                    "greenkhorn",
+                                    "sinkhorn_log"])
+@pytest.skip_arg(("nx", "method"), ("jax", "sinkhorn_epsilon_scaling"), reason="jax does not support sinkhorn_epsilon_scaling", getter=str)
+@pytest.skip_arg(("nx", "method"), ("jax", "greenkhorn"), reason="jax does not support greenkhorn", getter=str)
+def test_sinkhorn_variants_dtype_device(nx, method):
+    n = 100
+
+    x = np.random.randn(n, 2)
+    u = ot.utils.unif(n)
+
+    M = ot.dist(x, x)
+
+    for tp in nx.__type_list__:
+        print(nx.dtype_device(tp))
+
+        ub = nx.from_numpy(u, type_as=tp)
+        Mb = nx.from_numpy(M, type_as=tp)
+
+        Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+
+        nx.assert_same_dtype_device(Mb, Gb)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"])
+def test_sinkhorn2_variants_dtype_device(nx, method):
+    n = 100
+
+    x = np.random.randn(n, 2)
+    u = ot.utils.unif(n)
+
+    M = ot.dist(x, x)
+
+    for tp in nx.__type_list__:
+        print(nx.dtype_device(tp))
+
+        ub = nx.from_numpy(u, type_as=tp)
+        Mb = nx.from_numpy(M, type_as=tp)
+
+        lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
+
+        nx.assert_same_dtype_device(Mb, lossb)
+
+
 @pytest.skip_backend("jax")
 def test_sinkhorn_variants_multi_b(nx):
     # test sinkhorn
diff --git a/test/test_gromov.py b/test/test_gromov.py
index bcbcc3a..c4bc04c 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -75,6 +75,41 @@ def test_gromov(nx):
         q, Gb.sum(0), atol=1e-04)  # cf convergence gromov
 
 
+def test_gromov_dtype_device(nx):
+    # setup
+    n_samples = 50  # nb samples
+
+    mu_s = np.array([0, 0])
+    cov_s = np.array([[1, 0], [0, 1]])
+
+    xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+
+    xt = xs[::-1].copy()
+
+    p = ot.unif(n_samples)
+    q = ot.unif(n_samples)
+
+    C1 = ot.dist(xs, xs)
+    C2 = ot.dist(xt, xt)
+
+    C1 /= C1.max()
+    C2 /= C2.max()
+
+    for tp in nx.__type_list__:
+        print(nx.dtype_device(tp))
+
+        C1b = nx.from_numpy(C1, type_as=tp)
+        C2b = nx.from_numpy(C2, type_as=tp)
+        pb = nx.from_numpy(p, type_as=tp)
+        qb = nx.from_numpy(q, type_as=tp)
+
+        Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
+        gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+
+        nx.assert_same_dtype_device(C1b, Gb)
+        nx.assert_same_dtype_device(C1b, gw_valb)
+
+
 def test_gromov2_gradients():
     n_samples = 50  # nb samples
 
@@ -168,6 +203,46 @@ def test_entropic_gromov(nx):
         q, Gb.sum(0), atol=1e-04)  # cf convergence gromov
 
 
+@pytest.skip_backend("jax", reason="test very slow with jax backend")
+def test_entropic_gromov_dtype_device(nx):
+    # setup
+    n_samples = 50  # nb samples
+
+    mu_s = np.array([0, 0])
+    cov_s = np.array([[1, 0], [0, 1]])
+
+    xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+    xt = xs[::-1].copy()
+
+    p = ot.unif(n_samples)
+    q = ot.unif(n_samples)
+
+    C1 = ot.dist(xs, xs)
+    C2 = ot.dist(xt, xt)
+
+    C1 /= C1.max()
+    C2 /= C2.max()
+
+    for tp in nx.__type_list__:
+        print(nx.dtype_device(tp))
+
+        C1b = nx.from_numpy(C1, type_as=tp)
+        C2b = nx.from_numpy(C2, type_as=tp)
+        pb = nx.from_numpy(p, type_as=tp)
+        qb = nx.from_numpy(q, type_as=tp)
+
+        Gb = ot.gromov.entropic_gromov_wasserstein(
+            C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
+        )
+        gw_valb = ot.gromov.entropic_gromov_wasserstein2(
+            C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
+        )
+
+        nx.assert_same_dtype_device(C1b, Gb)
+        nx.assert_same_dtype_device(C1b, gw_valb)
+
+
 def test_pointwise_gromov(nx):
     n_samples = 50  # nb samples
 
diff --git a/test/test_ot.py b/test/test_ot.py
index dc3930a..92f26a7 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -88,8 +88,7 @@ def test_emd_emd2_types_devices(nx):
     M = ot.dist(x, y)
 
     for tp in nx.__type_list__:
-
-        print(tp.dtype)
+        print(nx.dtype_device(tp))
 
         ab = nx.from_numpy(a, type_as=tp)
         Mb = nx.from_numpy(M, type_as=tp)
@@ -98,9 +97,8 @@
 
         w = ot.emd2(ab, ab, Mb)
 
-        assert Gb.dtype == Mb.dtype
-        if not str(nx) == 'numpy':
-            assert w.dtype == Mb.dtype
+        nx.assert_same_dtype_device(Mb, Gb)
+        nx.assert_same_dtype_device(Mb, w)
 
 
 def test_emd2_gradients():
diff --git a/test/test_sliced.py b/test/test_sliced.py
index 0bd74ec..245202c 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -139,6 +139,28 @@ def test_sliced_backend(nx):
     assert np.allclose(val0, valb)
 
 
+def test_sliced_backend_type_devices(nx):
+    n = 100
+    rng = np.random.RandomState(0)
+
+    x = rng.randn(n, 2)
+    y = rng.randn(2 * n, 2)
+
+    P = rng.randn(2, 20)
+    P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+    for tp in nx.__type_list__:
+        print(nx.dtype_device(tp))
+
+        xb = nx.from_numpy(x, type_as=tp)
+        yb = nx.from_numpy(y, type_as=tp)
+        Pb = nx.from_numpy(P, type_as=tp)
+
+        valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
+
+        nx.assert_same_dtype_device(xb, valb)
+
+
 def test_max_sliced_backend(nx):
     n = 100
 
@@ -167,3 +189,25 @@ def test_max_sliced_backend(nx):
     valb = nx.to_numpy(ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb))
 
     assert np.allclose(val0, valb)
+
+
+def test_max_sliced_backend_type_devices(nx):
+    n = 100
+    rng = np.random.RandomState(0)
+
+    x = rng.randn(n, 2)
+    y = rng.randn(2 * n, 2)
+
+    P = rng.randn(2, 20)
+    P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+    for tp in nx.__type_list__:
+        print(nx.dtype_device(tp))
+
+        xb = nx.from_numpy(x, type_as=tp)
+        yb = nx.from_numpy(y, type_as=tp)
+        Pb = nx.from_numpy(P, type_as=tp)
+
+        valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
+
+        nx.assert_same_dtype_device(xb, valb)
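As a usage sketch (not part of the patch): the new `dtype_device` and `assert_same_dtype_device` backend methods are meant to be driven from a backend's `__type_list__`, which holds one prototype tensor per supported dtype/device pair, exactly as the tests above do. The sizes and the choice of `ot.sinkhorn` below are illustrative only.

    # Illustrative sketch of the new helpers; assumes this patch is applied.
    import numpy as np
    import ot
    from ot.backend import get_backend_list

    x = np.random.RandomState(0).randn(50, 2)
    M = ot.dist(x, x)   # cost matrix
    a = ot.unif(50)     # uniform weights

    for nx in get_backend_list():
        for tp in nx.__type_list__:
            print(nx.dtype_device(tp))          # e.g. (torch.float32, device('cpu'))
            ab = nx.from_numpy(a, type_as=tp)   # cast and move to tp's dtype/device
            Mb = nx.from_numpy(M, type_as=tp)
            Gb = ot.sinkhorn(ab, ab, Mb, 1)
            # raises "Dtype discrepancy" / "Device discrepancy" on mismatch
            nx.assert_same_dtype_device(Mb, Gb)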
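The conftest.py change also generalizes `skip_arg` from a single parameter to a tuple of parameters, so one decorator can skip a test only when several fixture values match at once; `skip_arg` is exposed on the `pytest` namespace by POT's conftest. A sketch mirroring its use in test_bregman.py above (the test name is hypothetical):

    import pytest

    @pytest.mark.parametrize("method", ["sinkhorn", "greenkhorn"])
    @pytest.skip_arg(("nx", "method"), ("jax", "greenkhorn"),
                     reason="jax does not support greenkhorn", getter=str)
    def test_example(nx, method):
        # skipped only when str(nx) == "jax" and method == "greenkhorn";
        # every other (backend, method) combination still runs
        ...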
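Finally, on why `JaxBackend` gains a `_change_device` wrapper: `jnp.zeros`, `jnp.ones`, `jnp.full` and `jnp.eye` honor the requested `dtype` but allocate on JAX's default device, so matching `type_as` also requires an explicit `jax.device_put`. A minimal illustration, assuming the JAX API the patch itself uses (arrays exposing `device_buffer`, which newer JAX versions have since replaced):

    import jax
    import jax.numpy as jnp

    # prototype tensor pinned to an explicit device, as in __type_list__
    tp = jax.device_put(jnp.array(1, dtype=jnp.float32), jax.devices()[0])

    z = jnp.zeros((3,), dtype=tp.dtype)               # right dtype, default device
    z = jax.device_put(z, tp.device_buffer.device())  # now also on tp's device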