summaryrefslogtreecommitdiff
path: root/test/test_gromov.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_gromov.py')
-rw-r--r--test/test_gromov.py15
1 files changed, 9 insertions, 6 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 81138ca..56414a8 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -9,6 +9,8 @@
import numpy as np
import ot
+import pytest
+
def test_gromov():
n_samples = 50 # nb samples
@@ -128,9 +130,10 @@ def test_gromov_barycenter():
np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
+@pytest.mark.filterwarnings("ignore:divide")
def test_gromov_entropic_barycenter():
- ns = 50
- nt = 60
+ ns = 20
+ nt = 30
Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
@@ -138,19 +141,19 @@ def test_gromov_entropic_barycenter():
C1 = ot.dist(Xs)
C2 = ot.dist(Xt)
- n_samples = 3
+ n_samples = 2
Cb = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
[ot.unif(ns), ot.unif(nt)
], ot.unif(n_samples), [.5, .5],
- 'square_loss', 2e-3,
- max_iter=100, tol=1e-3,
+ 'square_loss', 1e-3,
+ max_iter=50, tol=1e-5,
verbose=True)
np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
Cb2 = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
[ot.unif(ns), ot.unif(nt)
], ot.unif(n_samples), [.5, .5],
- 'kl_loss', 2e-3,
+ 'kl_loss', 1e-3,
max_iter=100, tol=1e-3)
np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))