summaryrefslogtreecommitdiff
path: root/test/test_gromov.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-03-21 10:08:17 +0100
committerRémi Flamary <remi.flamary@gmail.com>2018-03-21 10:08:17 +0100
commit55aaf7874c651235d44c34b89337df7694e55014 (patch)
tree821b86edbf0e14c11885f466a249dc1263090e80 /test/test_gromov.py
parent927395b40dae98bcf027b601b6df48a4318cfef2 (diff)
add test gromov + debug sklearn Basestimator
Diffstat (limited to 'test/test_gromov.py')
-rw-r--r--test/test_gromov.py25
1 files changed, 25 insertions, 0 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 625e62a..0384ee1 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -36,6 +36,18 @@ def test_gromov():
np.testing.assert_allclose(
q, G.sum(0), atol=1e-04) # cf convergence gromov
+ gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
+
+ G = log['T']
+
+ np.testing.assert_allclose(gw, 0, atol=1e-04, rtol=1e-4)
+
+ # check constratints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov
+
def test_entropic_gromov():
n_samples = 50 # nb samples
@@ -64,3 +76,16 @@ def test_entropic_gromov():
p, G.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
q, G.sum(0), atol=1e-04) # cf convergence gromov
+
+ gw, log = ot.gromov.entropic_gromov_wasserstein2(
+ C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True)
+
+ G = log['T']
+
+ np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e1)
+
+ # check constratints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov