diff options
Diffstat (limited to 'test/test_gromov.py')
-rw-r--r-- | test/test_gromov.py | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py index 625e62a..0384ee1 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -36,6 +36,18 @@ def test_gromov(): np.testing.assert_allclose(
q, G.sum(0), atol=1e-04) # cf convergence gromov
+ gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
+
+ G = log['T']
+
+ np.testing.assert_allclose(gw, 0, atol=1e-04, rtol=1e-4)
+
+ # check constratints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov
+
def test_entropic_gromov():
n_samples = 50 # nb samples
@@ -64,3 +76,16 @@ def test_entropic_gromov(): p, G.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
q, G.sum(0), atol=1e-04) # cf convergence gromov
+
+ gw, log = ot.gromov.entropic_gromov_wasserstein2(
+ C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True)
+
+ G = log['T']
+
+ np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e1)
+
+ # check constratints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov
|