From 0eac835c70cc1a13bb998f3b6cdb0515fafc05e1 Mon Sep 17 00:00:00 2001 From: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> Date: Fri, 5 Nov 2021 15:57:08 +0100 Subject: [MRG] Tests with types/device on sliced/bregman/gromov functions (#303) * First draft : making pytest use gpu for torch testing * bug solve * Revert "bug solve" This reverts commit 29b013abd162f8693128f26d8129186b79923609. * Revert "First draft : making pytest use gpu for torch testing" This reverts commit 2778175bcc338016c704efa4187d132fe5162e3a. * sliced * sliced * ot 1dsolver * bregman * better print * jax works with sinkhorn, sinkhorn_log and sinkhornn_stabilized, no need to skip them * gromov & entropic gromov --- test/test_gromov.py | 75 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) (limited to 'test/test_gromov.py') diff --git a/test/test_gromov.py b/test/test_gromov.py index bcbcc3a..c4bc04c 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -75,6 +75,41 @@ def test_gromov(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov +def test_gromov_dtype_device(nx): + # setup + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + C1b = nx.from_numpy(C1, type_as=tp) + C2b = nx.from_numpy(C2, type_as=tp) + pb = nx.from_numpy(p, type_as=tp) + qb = nx.from_numpy(q, type_as=tp) + + Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True) + gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) + + nx.assert_same_dtype_device(C1b, Gb) + nx.assert_same_dtype_device(C1b, gw_valb) + + def test_gromov2_gradients(): n_samples = 50 # nb samples @@ -168,6 +203,46 @@ def test_entropic_gromov(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov +@pytest.skip_backend("jax", reason="test very slow with jax backend") +def test_entropic_gromov_dtype_device(nx): + # setup + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + C1b = nx.from_numpy(C1, type_as=tp) + C2b = nx.from_numpy(C2, type_as=tp) + pb = nx.from_numpy(p, type_as=tp) + qb = nx.from_numpy(q, type_as=tp) + + Gb = ot.gromov.entropic_gromov_wasserstein( + C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True + ) + gw_valb = ot.gromov.entropic_gromov_wasserstein2( + C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True + ) + + nx.assert_same_dtype_device(C1b, Gb) + nx.assert_same_dtype_device(C1b, gw_valb) + + def test_pointwise_gromov(nx): n_samples = 50 # nb samples -- cgit v1.2.3