summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authortvayer <titouan.vayer@gmail.com>2019-05-29 17:05:38 +0200
committertvayer <titouan.vayer@gmail.com>2019-05-29 17:05:48 +0200
commite1bd94bb7e85a0d2fd0fcd7642b06da12c1db6db (patch)
tree1e85920b878ab715d211db56f99e25bfa2482fd3 /test
parentd4320382fa8873d15dcaec7adca3a4723c142515 (diff)
code review1
Diffstat (limited to 'test')
-rw-r--r--test/test_gromov.py57
1 files changed, 44 insertions, 13 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py
index cd180d4..ec85abf 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -2,6 +2,7 @@
# Author: Erwan Vautier <erwan.vautier@gmail.com>
# Nicolas Courty <ncourty@irisa.fr>
+# Titouan Vayer <titouan.vayer@irisa.fr>
#
# License: MIT License
@@ -10,6 +11,8 @@ import ot
def test_gromov():
+ np.random.seed(42)
+
n_samples = 50 # nb samples
mu_s = np.array([0, 0])
@@ -36,6 +39,11 @@ def test_gromov():
np.testing.assert_allclose(
q, G.sum(0), atol=1e-04) # cf convergence gromov
+ Id = (1 / n_samples) * np.eye(n_samples, n_samples)
+
+ np.testing.assert_allclose(
+ G, np.flipud(Id), atol=1e-04)
+
gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
G = log['T']
@@ -50,6 +58,8 @@ def test_gromov():
def test_entropic_gromov():
+ np.random.seed(42)
+
n_samples = 50 # nb samples
mu_s = np.array([0, 0])
@@ -92,6 +102,7 @@ def test_entropic_gromov():
def test_gromov_barycenter():
+ np.random.seed(42)
ns = 50
nt = 60
@@ -120,7 +131,7 @@ def test_gromov_barycenter():
def test_gromov_entropic_barycenter():
-
+ np.random.seed(42)
ns = 50
nt = 60
@@ -148,6 +159,8 @@ def test_gromov_entropic_barycenter():
def test_fgw():
+ np.random.seed(42)
+
n_samples = 50 # nb samples
mu_s = np.array([0, 0])
@@ -180,8 +193,26 @@ def test_fgw():
np.testing.assert_allclose(
q, G.sum(0), atol=1e-04) # cf convergence fgw
+ Id = (1 / n_samples) * np.eye(n_samples, n_samples)
+
+ np.testing.assert_allclose(
+ G, np.flipud(Id), atol=1e-04) # cf convergence gromov
+
+ fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
+
+ G = log['T']
+
+ np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1)
+
+ # check constratints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov
+
def test_fgw_barycenter():
+ np.random.seed(42)
ns = 50
nt = 60
@@ -196,28 +227,28 @@ def test_fgw_barycenter():
C2 = ot.dist(Xt)
n_samples = 3
- X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
- fixed_structure=False, fixed_features=False,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3)
+ X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
+ fixed_structure=False, fixed_features=False,
+ p=ot.unif(n_samples), loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
xalea = np.random.randn(n_samples, 2)
init_C = ot.dist(xalea, xalea)
- X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], ps=[ot.unif(ns), ot.unif(nt)], lambdas=[.5, .5], alpha=0.5,
- fixed_structure=True, init_C=init_C, fixed_features=False,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3)
+ X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], ps=[ot.unif(ns), ot.unif(nt)], lambdas=[.5, .5], alpha=0.5,
+ fixed_structure=True, init_C=init_C, fixed_features=False,
+ p=ot.unif(n_samples), loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
init_X = np.random.randn(n_samples, ys.shape[1])
- X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
- fixed_structure=False, fixed_features=True, init_X=init_X,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3)
+ X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
+ fixed_structure=False, fixed_features=True, init_X=init_X,
+ p=ot.unif(n_samples), loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))