summaryrefslogtreecommitdiff
path: root/test/test_gromov.py
diff options
context:
space:
mode:
authorncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>2021-11-05 15:57:08 +0100
committerGitHub <noreply@github.com>2021-11-05 15:57:08 +0100
commit0eac835c70cc1a13bb998f3b6cdb0515fafc05e1 (patch)
treeb0c0fbce0109ba460a67a6356dc0ff03e2b3c1d5 /test/test_gromov.py
parent0e431c203a66c6d48e6bb1efeda149460472a0f0 (diff)
[MRG] Tests with types/device on sliced/bregman/gromov functions (#303)
* First draft : making pytest use gpu for torch testing * bug solve * Revert "bug solve" This reverts commit 29b013abd162f8693128f26d8129186b79923609. * Revert "First draft : making pytest use gpu for torch testing" This reverts commit 2778175bcc338016c704efa4187d132fe5162e3a. * sliced * sliced * ot 1dsolver * bregman * better print * jax works with sinkhorn, sinkhorn_log and sinkhornn_stabilized, no need to skip them * gromov & entropic gromov
Diffstat (limited to 'test/test_gromov.py')
-rw-r--r--test/test_gromov.py75
1 files changed, 75 insertions, 0 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py
index bcbcc3a..c4bc04c 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -75,6 +75,41 @@ def test_gromov(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+def test_gromov_dtype_device(nx):
+ # setup
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+
+ xt = xs[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ C1b = nx.from_numpy(C1, type_as=tp)
+ C2b = nx.from_numpy(C2, type_as=tp)
+ pb = nx.from_numpy(p, type_as=tp)
+ qb = nx.from_numpy(q, type_as=tp)
+
+ Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+
+ nx.assert_same_dtype_device(C1b, Gb)
+ nx.assert_same_dtype_device(C1b, gw_valb)
+
+
def test_gromov2_gradients():
n_samples = 50 # nb samples
@@ -168,6 +203,46 @@ def test_entropic_gromov(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+@pytest.skip_backend("jax", reason="test very slow with jax backend")
+def test_entropic_gromov_dtype_device(nx):
+ # setup
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+ xt = xs[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ C1b = nx.from_numpy(C1, type_as=tp)
+ C2b = nx.from_numpy(C2, type_as=tp)
+ pb = nx.from_numpy(p, type_as=tp)
+ qb = nx.from_numpy(q, type_as=tp)
+
+ Gb = ot.gromov.entropic_gromov_wasserstein(
+ C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
+ )
+ gw_valb = ot.gromov.entropic_gromov_wasserstein2(
+ C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
+ )
+
+ nx.assert_same_dtype_device(C1b, Gb)
+ nx.assert_same_dtype_device(C1b, gw_valb)
+
+
def test_pointwise_gromov(nx):
n_samples = 50 # nb samples