diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2018-03-21 10:08:17 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2018-03-21 10:08:17 +0100 |
commit | 55aaf7874c651235d44c34b89337df7694e55014 (patch) | |
tree | 821b86edbf0e14c11885f466a249dc1263090e80 /test/test_gromov.py | |
parent | 927395b40dae98bcf027b601b6df48a4318cfef2 (diff) |
add test gromov + debug sklearn Basestimator
Diffstat (limited to 'test/test_gromov.py')
-rw-r--r-- | test/test_gromov.py | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py index 625e62a..0384ee1 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -36,6 +36,18 @@ def test_gromov(): np.testing.assert_allclose(
q, G.sum(0), atol=1e-04) # cf convergence gromov
+ gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
+
+ G = log['T']
+
+ np.testing.assert_allclose(gw, 0, atol=1e-04, rtol=1e-4)
+
+ # check constratints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov
+
def test_entropic_gromov():
n_samples = 50 # nb samples
@@ -64,3 +76,16 @@ def test_entropic_gromov(): p, G.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
q, G.sum(0), atol=1e-04) # cf convergence gromov
+
+ gw, log = ot.gromov.entropic_gromov_wasserstein2(
+ C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True)
+
+ G = log['T']
+
+ np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e1)
+
+ # check constratints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov
|