summaryrefslogtreecommitdiff
path: root/test/test_gromov.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_gromov.py')
-rw-r--r--test/test_gromov.py119
1 files changed, 111 insertions, 8 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py
index c4bc04c..4b995d5 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -9,7 +9,7 @@
import numpy as np
import ot
from ot.backend import NumpyBackend
-from ot.backend import torch
+from ot.backend import torch, tf
import pytest
@@ -54,9 +54,12 @@ def test_gromov(nx):
gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True)
+ gwb = nx.to_numpy(gwb)
gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False)
- gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ gw_valb = nx.to_numpy(
+ ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ )
G = log['T']
Gb = nx.to_numpy(logb['T'])
@@ -110,6 +113,45 @@ def test_gromov_dtype_device(nx):
nx.assert_same_dtype_device(C1b, gw_valb)
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_gromov_device_tf():
+ nx = ot.backend.TensorflowBackend()
+ n_samples = 50 # nb samples
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+ xt = xs[::-1].copy()
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+ Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ nx.assert_same_dtype_device(C1b, Gb)
+ nx.assert_same_dtype_device(C1b, gw_valb)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+ Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ nx.assert_same_dtype_device(C1b, Gb)
+ nx.assert_same_dtype_device(C1b, gw_valb)
+ assert nx.dtype_device(Gb)[1].startswith("GPU")
+
+
def test_gromov2_gradients():
n_samples = 50 # nb samples
@@ -147,6 +189,7 @@ def test_gromov2_gradients():
@pytest.skip_backend("jax", reason="test very slow with jax backend")
+@pytest.skip_backend("tf", reason="test very slow with tf backend")
def test_entropic_gromov(nx):
n_samples = 50 # nb samples
@@ -188,6 +231,7 @@ def test_entropic_gromov(nx):
C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True)
gwb, logb = ot.gromov.entropic_gromov_wasserstein2(
C1b, C2b, pb, qb, 'kl_loss', epsilon=1e-2, log=True)
+ gwb = nx.to_numpy(gwb)
G = log['T']
Gb = nx.to_numpy(logb['T'])
@@ -204,6 +248,7 @@ def test_entropic_gromov(nx):
@pytest.skip_backend("jax", reason="test very slow with jax backend")
+@pytest.skip_backend("tf", reason="test very slow with tf backend")
def test_entropic_gromov_dtype_device(nx):
# setup
n_samples = 50 # nb samples
@@ -287,8 +332,8 @@ def test_pointwise_gromov(nx):
np.testing.assert_allclose(
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
- np.testing.assert_allclose(logb['gw_dist_estimated'], 0.0, atol=1e-08)
- np.testing.assert_allclose(logb['gw_dist_std'], 0.0, atol=1e-08)
+ np.testing.assert_allclose(float(logb['gw_dist_estimated']), 0.0, atol=1e-08)
+ np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0, atol=1e-08)
G, log = ot.gromov.pointwise_gromov_wasserstein(
C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42)
@@ -298,10 +343,11 @@ def test_pointwise_gromov(nx):
Gb = nx.to_numpy(nx.todense(Gb))
np.testing.assert_allclose(G, Gb, atol=1e-06)
- np.testing.assert_allclose(logb['gw_dist_estimated'], 0.10342276348494964, atol=1e-8)
- np.testing.assert_allclose(logb['gw_dist_std'], 0.0015952535464736394, atol=1e-8)
+ np.testing.assert_allclose(float(logb['gw_dist_estimated']), 0.10342276348494964, atol=1e-8)
+ np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0015952535464736394, atol=1e-8)
+@pytest.skip_backend("tf", reason="test very slow with tf backend")
@pytest.skip_backend("jax", reason="test very slow with jax backend")
def test_sampled_gromov(nx):
n_samples = 50 # nb samples
@@ -346,8 +392,8 @@ def test_sampled_gromov(nx):
np.testing.assert_allclose(
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
- np.testing.assert_allclose(logb['gw_dist_estimated'], 0.05679474884977278, atol=1e-08)
- np.testing.assert_allclose(logb['gw_dist_std'], 0.0005986592106971995, atol=1e-08)
+ np.testing.assert_allclose(float(logb['gw_dist_estimated']), 0.05679474884977278, atol=1e-08)
+ np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0005986592106971995, atol=1e-08)
def test_gromov_barycenter(nx):
@@ -381,6 +427,20 @@ def test_gromov_barycenter(nx):
np.testing.assert_allclose(Cb, Cbb, atol=1e-06)
np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples))
+ # test of gromov_barycenters with `log` on
+ Cb_, err_ = ot.gromov.gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ )
+ Cbb_, errb_ = ot.gromov.gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ )
+ Cbb_ = nx.to_numpy(Cbb_)
+ np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06)
+ np.testing.assert_array_almost_equal(err_['err'], errb_['err'])
+ np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples))
+
Cb2 = ot.gromov.gromov_barycenters(
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
'kl_loss', max_iter=100, tol=1e-3, random_state=42
@@ -392,6 +452,20 @@ def test_gromov_barycenter(nx):
np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06)
np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples))
+ # test of gromov_barycenters with `log` on
+ Cb2_, err2_ = ot.gromov.gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'kl_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ )
+ Cb2b_, err2b_ = ot.gromov.gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'kl_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ )
+ Cb2b_ = nx.to_numpy(Cb2b_)
+ np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06)
+ np.testing.assert_array_almost_equal(err2_['err'], err2_['err'])
+ np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples))
+
@pytest.mark.filterwarnings("ignore:divide")
def test_gromov_entropic_barycenter(nx):
@@ -425,6 +499,20 @@ def test_gromov_entropic_barycenter(nx):
np.testing.assert_allclose(Cb, Cbb, atol=1e-06)
np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples))
+ # test of entropic_gromov_barycenters with `log` on
+ Cb_, err_ = ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'square_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ )
+ Cbb_, errb_ = ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'square_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ )
+ Cbb_ = nx.to_numpy(Cbb_)
+ np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06)
+ np.testing.assert_array_almost_equal(err_['err'], errb_['err'])
+ np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples))
+
Cb2 = ot.gromov.entropic_gromov_barycenters(
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42
@@ -436,6 +524,20 @@ def test_gromov_entropic_barycenter(nx):
np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06)
np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples))
+ # test of entropic_gromov_barycenters with `log` on
+ Cb2_, err2_ = ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'kl_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ )
+ Cb2b_, err2b_ = ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'kl_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ )
+ Cb2b_ = nx.to_numpy(Cb2b_)
+ np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06)
+ np.testing.assert_array_almost_equal(err2_['err'], err2_['err'])
+ np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples))
+
def test_fgw(nx):
n_samples = 50 # nb samples
@@ -486,6 +588,7 @@ def test_fgw(nx):
fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True)
+ fgwb = nx.to_numpy(fgwb)
G = log['T']
Gb = nx.to_numpy(logb['T'])