"""Tests for module gromov """

# Author: Erwan Vautier
#         Nicolas Courty
#
# License: MIT License

import numpy as np
import ot


def test_gromov():
    """Check that the Gromov-Wasserstein coupling satisfies its marginals.

    The target point cloud is the source point cloud with its samples in
    reverse order.  Both intra-domain distance matrices therefore describe
    the same metric space up to a permutation.  The coupling ``G`` returned
    by ``ot.gromov_wasserstein`` must have row sums equal to ``p`` and
    column sums equal to ``q``.
    """
    # Seed NumPy's global RNG so the randomly drawn samples are reproducible
    # between runs.  This assumes get_2D_samples_gauss draws from the global
    # RNG when no random state is passed; confirm against the installed POT
    # version.
    np.random.seed(0)

    n = 50  # nb samples

    mu_s = np.array([0, 0])
    cov_s = np.array([[1, 0], [0, 1]])

    xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s)

    # Same samples in reverse order: xt[i] == xs[n - 1 - i].
    xt = xs[::-1]

    p = ot.unif(n)
    q = ot.unif(n)

    C1 = ot.dist(xs, xs)
    C2 = ot.dist(xt, xt)

    # Scale both cost matrices so their largest entry is 1.
    C1 /= C1.max()
    C2 /= C2.max()

    G = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)

    # Check the marginal constraints.  The loose tolerance allows for the
    # Gromov solver stopping at its convergence criterion rather than at an
    # exact solution.
    np.testing.assert_allclose(p, G.sum(1), atol=1e-04)
    np.testing.assert_allclose(q, G.sum(0), atol=1e-04)