diff options
Diffstat (limited to 'test/test_gromov.py')
-rw-r--r-- | test/test_gromov.py | 19 |
1 files changed, 12 insertions, 7 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py index c4bc04c..5c181f2 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -54,9 +54,12 @@ def test_gromov(nx): gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True)
+ gwb = nx.to_numpy(gwb)
gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False)
- gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ gw_valb = nx.to_numpy(
+ ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ )
G = log['T']
Gb = nx.to_numpy(logb['T'])
@@ -188,6 +191,7 @@ def test_entropic_gromov(nx): C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True)
gwb, logb = ot.gromov.entropic_gromov_wasserstein2(
C1b, C2b, pb, qb, 'kl_loss', epsilon=1e-2, log=True)
+ gwb = nx.to_numpy(gwb)
G = log['T']
Gb = nx.to_numpy(logb['T'])
@@ -287,8 +291,8 @@ def test_pointwise_gromov(nx): np.testing.assert_allclose(
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
- np.testing.assert_allclose(logb['gw_dist_estimated'], 0.0, atol=1e-08)
- np.testing.assert_allclose(logb['gw_dist_std'], 0.0, atol=1e-08)
+ np.testing.assert_allclose(float(logb['gw_dist_estimated']), 0.0, atol=1e-08)
+ np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0, atol=1e-08)
G, log = ot.gromov.pointwise_gromov_wasserstein(
C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42)
@@ -298,8 +302,8 @@ def test_pointwise_gromov(nx): Gb = nx.to_numpy(nx.todense(Gb))
np.testing.assert_allclose(G, Gb, atol=1e-06)
- np.testing.assert_allclose(logb['gw_dist_estimated'], 0.10342276348494964, atol=1e-8)
- np.testing.assert_allclose(logb['gw_dist_std'], 0.0015952535464736394, atol=1e-8)
+ np.testing.assert_allclose(float(logb['gw_dist_estimated']), 0.10342276348494964, atol=1e-8)
+ np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0015952535464736394, atol=1e-8)
@pytest.skip_backend("jax", reason="test very slow with jax backend")
@@ -346,8 +350,8 @@ def test_sampled_gromov(nx): np.testing.assert_allclose(
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
- np.testing.assert_allclose(logb['gw_dist_estimated'], 0.05679474884977278, atol=1e-08)
- np.testing.assert_allclose(logb['gw_dist_std'], 0.0005986592106971995, atol=1e-08)
+ np.testing.assert_allclose(float(logb['gw_dist_estimated']), 0.05679474884977278, atol=1e-08)
+ np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0005986592106971995, atol=1e-08)
def test_gromov_barycenter(nx):
@@ -486,6 +490,7 @@ def test_fgw(nx): fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True)
+ fgwb = nx.to_numpy(fgwb)
G = log['T']
Gb = nx.to_numpy(logb['T'])
|