summaryrefslogtreecommitdiff
path: root/test/test_gromov.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-03-21 10:28:19 +0100
committerRémi Flamary <remi.flamary@gmail.com>2018-03-21 10:28:19 +0100
commit7095e03eb339bcf32d91c5a8857ecc3f3d0c45c0 (patch)
treea96079c5e54a6925dac4914c7c595dce63bf5097 /test/test_gromov.py
parent64ef33d09906a1aebd3c8294ffd7720475ab926b (diff)
gtomov barycenter tests
Diffstat (limited to 'test/test_gromov.py')
-rw-r--r--test/test_gromov.py31
1 files changed, 29 insertions, 2 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 0dfd54e..d865380 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -40,7 +40,7 @@ def test_gromov():
G = log['T']
- np.testing.assert_allclose(gw, 0, atol=1e-2, rtol=1e-2)
+ np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
# check constratints
np.testing.assert_allclose(
@@ -82,10 +82,37 @@ def test_entropic_gromov():
G = log['T']
- np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e1)
+ np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
# check constratints
np.testing.assert_allclose(
p, G.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
q, G.sum(0), atol=1e-04) # cf convergence gromov
+
+
+def test_gromov_barycenter():
+
+ ns = 50
+ nt = 60
+
+ Xs, ys = ot.datasets.get_data_classif('3gauss', ns)
+ Xt, yt = ot.datasets.get_data_classif('3gauss2', nt)
+
+ C1 = ot.dist(Xs)
+ C2 = ot.dist(Xt)
+
+ n_samples = 3
+ Cb = ot.gromov.gromov_barycenters(n_samples, [C1, C2],
+ [ot.unif(ns), ot.unif(nt)
+ ], ot.unif(n_samples), [.5, .5],
+ 'square_loss', # 5e-4,
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+
+ Cb2 = ot.gromov.gromov_barycenters(n_samples, [C1, C2],
+ [ot.unif(ns), ot.unif(nt)
+ ], ot.unif(n_samples), [.5, .5],
+ 'kl_loss', # 5e-4,
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))