From b4665fe3c12780af4228cb1fb7dc8e1159c81f63 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Fri, 16 Feb 2018 14:08:55 +0100 Subject: should pass tests now --- test/test_gromov.py | 31 ++++++++++++++++++++++++++++++- test/test_plot.py | 2 ++ 2 files changed, 32 insertions(+), 1 deletion(-) (limited to 'test') diff --git a/test/test_gromov.py b/test/test_gromov.py index e808292..625e62a 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -28,7 +28,36 @@ def test_gromov(): C1 /= C1.max() C2 /= C2.max() - G = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4) + G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss') + + # check constratints + np.testing.assert_allclose( + p, G.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, G.sum(0), atol=1e-04) # cf convergence gromov + + +def test_entropic_gromov(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + G = ot.gromov.entropic_gromov_wasserstein( + C1, C2, p, q, 'square_loss', epsilon=5e-4) # check constratints np.testing.assert_allclose( diff --git a/test/test_plot.py b/test/test_plot.py index f7debee..a50ed14 100644 --- a/test/test_plot.py +++ b/test/test_plot.py @@ -12,6 +12,7 @@ matplotlib.use('Agg') def test_plot1D_mat(): import ot + import ot.plot n_bins = 100 # nb bins @@ -32,6 +33,7 @@ def test_plot1D_mat(): def test_plot2D_samples_mat(): import ot + import ot.plot n_bins = 50 # nb samples -- cgit v1.2.3