summaryrefslogtreecommitdiff
path: root/test/test_gromov.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-02-16 14:08:55 +0100
committerRémi Flamary <remi.flamary@gmail.com>2018-02-16 14:08:55 +0100
commitb4665fe3c12780af4228cb1fb7dc8e1159c81f63 (patch)
tree46e11035c4432acc27777e745bfd75b8ad7f7372 /test/test_gromov.py
parentd41ffdb24de5cb234237482b3f332b06514b10f0 (diff)
should pass tests now
Diffstat (limited to 'test/test_gromov.py')
-rw-r--r--test/test_gromov.py31
1 files changed, 30 insertions, 1 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py
index e808292..625e62a 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -28,7 +28,36 @@ def test_gromov():
C1 /= C1.max()
C2 /= C2.max()
- G = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
+ G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss')
+
+ # check constratints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov
+
+
+def test_entropic_gromov():
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)
+
+ xt = xs[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ G = ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', epsilon=5e-4)
# check constratints
np.testing.assert_allclose(