summaryrefslogtreecommitdiff
path: root/test/test_gromov.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_gromov.py')
-rw-r--r--test/test_gromov.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 43da9fc..81138ca 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -181,7 +181,7 @@ def test_fgw():
M = ot.dist(ys, yt)
M /= M.max()
- G = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5)
+ G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
# check constratints
np.testing.assert_allclose(
@@ -242,9 +242,9 @@ def test_fgw_barycenter():
init_X = np.random.randn(n_samples, ys.shape[1])
- X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
- fixed_structure=False, fixed_features=True, init_X=init_X,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3)
+ X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
+ fixed_structure=False, fixed_features=True, init_X=init_X,
+ p=ot.unif(n_samples), loss_fun='square_loss',
+ max_iter=100, tol=1e-3, log=True)
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))