From 3007f1da1094f93fa4216386666085cf60316b04 Mon Sep 17 00:00:00 2001 From: Nicolas Courty Date: Thu, 31 Aug 2017 16:44:18 +0200 Subject: Minor corrections suggested by @agramfort + new barycenter example + test function --- test/test_gromov.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 test/test_gromov.py (limited to 'test/test_gromov.py') diff --git a/test/test_gromov.py b/test/test_gromov.py new file mode 100644 index 0000000..75eeaab --- /dev/null +++ b/test/test_gromov.py @@ -0,0 +1,38 @@ +"""Tests for module gromov """ + +# Author: Erwan Vautier +# Nicolas Courty +# +# License: MIT License + +import numpy as np +import ot + + +def test_gromov(): + n = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s) + + xt = [xs[n - (i + 1)] for i in range(n)] + xt = np.array(xt) + + p = ot.unif(n) + q = ot.unif(n) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + G = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4) + + # check constratints + np.testing.assert_allclose( + p, G.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, G.sum(0), atol=1e-04) # cf convergence gromov -- cgit v1.2.3