summaryrefslogtreecommitdiff
path: root/test/test_gromov.py
diff options
context:
space:
mode:
authortvayer <titouan.vayer@gmail.com>2019-05-29 14:11:48 +0200
committertvayer <titouan.vayer@gmail.com>2019-05-29 14:11:48 +0200
commit6484c9ea301fc15ae53b4afe134941909f581ffe (patch)
tree90a5e1487524696139722990e2f2d737d3206aef /test/test_gromov.py
parent11c2c26ff897e5763e714546e7021cffa8d673a7 (diff)
Tests + contributions
Diffstat (limited to 'test/test_gromov.py')
-rw-r--r--test/test_gromov.py75
1 files changed, 75 insertions, 0 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py
index fb86274..07cd874 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -143,3 +143,78 @@ def test_gromov_entropic_barycenter():
'kl_loss', 2e-3,
max_iter=100, tol=1e-3)
np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
+
+def test_fgw():
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
+
+ xt = xs[::-1].copy()
+
+ ys = np.random.randn(xs.shape[0],2)
+ yt= ys[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ M=ot.dist(ys,yt)
+ M/=M.max()
+
+ G = ot.gromov.fused_gromov_wasserstein(M,C1, C2, p, q, 'square_loss',alpha=0.5)
+
+ # check constratints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence fgw
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence fgw
+
+
+def test_fgw_barycenter():
+
+ ns = 50
+ nt = 60
+
+ Xs, ys = ot.datasets.make_data_classif('3gauss', ns)
+ Xt, yt = ot.datasets.make_data_classif('3gauss2', nt)
+
+ ys = np.random.randn(Xs.shape[0],2)
+ yt= np.random.randn(Xt.shape[0],2)
+
+ C1 = ot.dist(Xs)
+ C2 = ot.dist(Xt)
+
+ n_samples = 3
+ X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],[ot.unif(ns), ot.unif(nt)],[.5, .5],0.5,
+ fixed_structure=False,fixed_features=False,
+ p=ot.unif(n_samples),loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(C.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+
+ xalea = np.random.randn(n_samples, 2)
+ init_C = ot.dist(xalea, xalea)
+
+ X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],ps=[ot.unif(ns), ot.unif(nt)],lambdas=[.5, .5],alpha=0.5,
+ fixed_structure=True,init_C=init_C,fixed_features=False,
+ p=ot.unif(n_samples),loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(C.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+
+ init_X=np.random.randn(n_samples,ys.shape[1])
+
+ X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],[ot.unif(ns), ot.unif(nt)],[.5, .5],0.5,
+ fixed_structure=False,fixed_features=True, init_X=init_X,
+ p=ot.unif(n_samples),loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(C.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))