From 6484c9ea301fc15ae53b4afe134941909f581ffe Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 14:11:48 +0200 Subject: Tests + contributions --- test/test_gromov.py | 75 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) (limited to 'test/test_gromov.py') diff --git a/test/test_gromov.py b/test/test_gromov.py index fb86274..07cd874 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -143,3 +143,78 @@ def test_gromov_entropic_barycenter(): 'kl_loss', 2e-3, max_iter=100, tol=1e-3) np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples)) + +def test_fgw(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s) + + xt = xs[::-1].copy() + + ys = np.random.randn(xs.shape[0],2) + yt= ys[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + M=ot.dist(ys,yt) + M/=M.max() + + G = ot.gromov.fused_gromov_wasserstein(M,C1, C2, p, q, 'square_loss',alpha=0.5) + + # check constratints + np.testing.assert_allclose( + p, G.sum(1), atol=1e-04) # cf convergence fgw + np.testing.assert_allclose( + q, G.sum(0), atol=1e-04) # cf convergence fgw + + +def test_fgw_barycenter(): + + ns = 50 + nt = 60 + + Xs, ys = ot.datasets.make_data_classif('3gauss', ns) + Xt, yt = ot.datasets.make_data_classif('3gauss2', nt) + + ys = np.random.randn(Xs.shape[0],2) + yt= np.random.randn(Xt.shape[0],2) + + C1 = ot.dist(Xs) + C2 = ot.dist(Xt) + + n_samples = 3 + X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],[ot.unif(ns), ot.unif(nt)],[.5, .5],0.5, + fixed_structure=False,fixed_features=False, + p=ot.unif(n_samples),loss_fun='square_loss', + max_iter=100, tol=1e-3) + np.testing.assert_allclose(C.shape, (n_samples, n_samples)) + np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) + + xalea = np.random.randn(n_samples, 2) + init_C = ot.dist(xalea, xalea) + + X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],ps=[ot.unif(ns), ot.unif(nt)],lambdas=[.5, .5],alpha=0.5, + fixed_structure=True,init_C=init_C,fixed_features=False, + p=ot.unif(n_samples),loss_fun='square_loss', + max_iter=100, tol=1e-3) + np.testing.assert_allclose(C.shape, (n_samples, n_samples)) + np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) + + init_X=np.random.randn(n_samples,ys.shape[1]) + + X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],[ot.unif(ns), ot.unif(nt)],[.5, .5],0.5, + fixed_structure=False,fixed_features=True, init_X=init_X, + p=ot.unif(n_samples),loss_fun='square_loss', + max_iter=100, tol=1e-3) + np.testing.assert_allclose(C.shape, (n_samples, n_samples)) + np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) -- cgit v1.2.3