summaryrefslogtreecommitdiff
path: root/test/test_gromov.py
diff options
context:
space:
mode:
authorNicolas Courty <Nico@pc-mna-08.univ-ubs.fr>2017-08-31 16:44:18 +0200
committerNicolas Courty <Nico@MacBook-Pro-de-Nicolas.local>2017-09-01 11:09:13 +0200
commit8c525174bb664cafa98dfff73dce9d42d7818f71 (patch)
treed353a0952f29c8cf3cb71bdd198f9acc4afa58da /test/test_gromov.py
parent93dee553a3dd5d6e3c5a5d325bb6333e8eb24dee (diff)
Minor corrections suggested by @agramfort + new barycenter example + test function
Diffstat (limited to 'test/test_gromov.py')
-rw-r--r--test/test_gromov.py38
1 files changed, 38 insertions, 0 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py
new file mode 100644
index 0000000..75eeaab
--- /dev/null
+++ b/test/test_gromov.py
@@ -0,0 +1,38 @@
+"""Tests for module gromov """
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+import ot
+
+
+def test_gromov():
+ n = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s)
+
+ xt = [xs[n - (i + 1)] for i in range(n)]
+ xt = np.array(xt)
+
+ p = ot.unif(n)
+ q = ot.unif(n)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ G = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
+
+ # check constratints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov