summaryrefslogtreecommitdiff
path: root/test/test_gromov.py
diff options
context:
space:
mode:
authorNicolas Courty <Nico@pc-mna-08.univ-ubs.fr>2017-08-31 16:44:18 +0200
committerNicolas Courty <Nico@pc-mna-08.univ-ubs.fr>2017-08-31 16:44:18 +0200
commit3007f1da1094f93fa4216386666085cf60316b04 (patch)
tree5e07b1674769403f2e09476b7d73f1e00a845384 /test/test_gromov.py
parent0a68bf4e83ee9092c3f3878115fea894922d7d56 (diff)
Minor corrections suggested by @agramfort + new barycenter example + test function
Diffstat (limited to 'test/test_gromov.py')
-rw-r--r--test/test_gromov.py38
1 files changed, 38 insertions, 0 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py
new file mode 100644
index 0000000..75eeaab
--- /dev/null
+++ b/test/test_gromov.py
@@ -0,0 +1,38 @@
+"""Tests for module gromov """
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+import ot
+
+
+def test_gromov():
+ n = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s)
+
+ xt = [xs[n - (i + 1)] for i in range(n)]
+ xt = np.array(xt)
+
+ p = ot.unif(n)
+ q = ot.unif(n)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ G = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
+
+ # check constratints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov