From e1bd94bb7e85a0d2fd0fcd7642b06da12c1db6db Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 17:05:38 +0200 Subject: code review1 --- test/test_gromov.py | 57 +++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 44 insertions(+), 13 deletions(-) (limited to 'test') diff --git a/test/test_gromov.py b/test/test_gromov.py index cd180d4..ec85abf 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -2,6 +2,7 @@ # Author: Erwan Vautier # Nicolas Courty +# Titouan Vayer # # License: MIT License @@ -10,6 +11,8 @@ import ot def test_gromov(): + np.random.seed(42) + n_samples = 50 # nb samples mu_s = np.array([0, 0]) @@ -36,6 +39,11 @@ def test_gromov(): np.testing.assert_allclose( q, G.sum(0), atol=1e-04) # cf convergence gromov + Id = (1 / n_samples) * np.eye(n_samples, n_samples) + + np.testing.assert_allclose( + G, np.flipud(Id), atol=1e-04) + gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True) G = log['T'] @@ -50,6 +58,8 @@ def test_gromov(): def test_entropic_gromov(): + np.random.seed(42) + n_samples = 50 # nb samples mu_s = np.array([0, 0]) @@ -92,6 +102,7 @@ def test_entropic_gromov(): def test_gromov_barycenter(): + np.random.seed(42) ns = 50 nt = 60 @@ -120,7 +131,7 @@ def test_gromov_barycenter(): def test_gromov_entropic_barycenter(): - + np.random.seed(42) ns = 50 nt = 60 @@ -148,6 +159,8 @@ def test_gromov_entropic_barycenter(): def test_fgw(): + np.random.seed(42) + n_samples = 50 # nb samples mu_s = np.array([0, 0]) @@ -180,8 +193,26 @@ def test_fgw(): np.testing.assert_allclose( q, G.sum(0), atol=1e-04) # cf convergence fgw + Id = (1 / n_samples) * np.eye(n_samples, n_samples) + + np.testing.assert_allclose( + G, np.flipud(Id), atol=1e-04) # cf convergence gromov + + fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True) + + G = log['T'] + + np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) + + # check constratints + np.testing.assert_allclose( + p, G.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, G.sum(0), atol=1e-04) # cf convergence gromov + def test_fgw_barycenter(): + np.random.seed(42) ns = 50 nt = 60 @@ -196,28 +227,28 @@ def test_fgw_barycenter(): C2 = ot.dist(Xt) n_samples = 3 - X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5, - fixed_structure=False, fixed_features=False, - p=ot.unif(n_samples), loss_fun='square_loss', - max_iter=100, tol=1e-3) + X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, + p=ot.unif(n_samples), loss_fun='square_loss', + max_iter=100, tol=1e-3) np.testing.assert_allclose(C.shape, (n_samples, n_samples)) np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) xalea = np.random.randn(n_samples, 2) init_C = ot.dist(xalea, xalea) - X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], ps=[ot.unif(ns), ot.unif(nt)], lambdas=[.5, .5], alpha=0.5, - fixed_structure=True, init_C=init_C, fixed_features=False, - p=ot.unif(n_samples), loss_fun='square_loss', - max_iter=100, tol=1e-3) + X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], ps=[ot.unif(ns), ot.unif(nt)], lambdas=[.5, .5], alpha=0.5, + fixed_structure=True, init_C=init_C, fixed_features=False, + p=ot.unif(n_samples), loss_fun='square_loss', + max_iter=100, tol=1e-3) np.testing.assert_allclose(C.shape, (n_samples, n_samples)) np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) init_X = np.random.randn(n_samples, ys.shape[1]) - X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5, - fixed_structure=False, fixed_features=True, init_X=init_X, - p=ot.unif(n_samples), loss_fun='square_loss', - max_iter=100, tol=1e-3) + X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5, + fixed_structure=False, fixed_features=True, init_X=init_X, + p=ot.unif(n_samples), loss_fun='square_loss', + max_iter=100, tol=1e-3) np.testing.assert_allclose(C.shape, (n_samples, n_samples)) np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) -- cgit v1.2.3