summaryrefslogtreecommitdiff
path: root/test/test_gromov.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2021-06-17 11:46:37 +0200
committerGitHub <noreply@github.com>2021-06-17 11:46:37 +0200
commit8ef3341a472909f223ec0f678f11f136f55c1406 (patch)
tree162cbd7f7a0e3bd87c2e3b5fe61f70f25ec951c5 /test/test_gromov.py
parent2dbeeda9308029a8e8db56bed07d48f4d5718efb (diff)
[MRG] Speedup tests (#262)
* speedup tests * add color to tests and timings * add test unbalanced * stupid missing -
Diffstat (limited to 'test/test_gromov.py')
-rw-r--r--test/test_gromov.py15
1 files changed, 9 insertions, 6 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 81138ca..56414a8 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -9,6 +9,8 @@
import numpy as np
import ot
+import pytest
+
def test_gromov():
n_samples = 50 # nb samples
@@ -128,9 +130,10 @@ def test_gromov_barycenter():
np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
+@pytest.mark.filterwarnings("ignore:divide")
def test_gromov_entropic_barycenter():
- ns = 50
- nt = 60
+ ns = 20
+ nt = 30
Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
@@ -138,19 +141,19 @@ def test_gromov_entropic_barycenter():
C1 = ot.dist(Xs)
C2 = ot.dist(Xt)
- n_samples = 3
+ n_samples = 2
Cb = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
[ot.unif(ns), ot.unif(nt)
], ot.unif(n_samples), [.5, .5],
- 'square_loss', 2e-3,
- max_iter=100, tol=1e-3,
+ 'square_loss', 1e-3,
+ max_iter=50, tol=1e-5,
verbose=True)
np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
Cb2 = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
[ot.unif(ns), ot.unif(nt)
], ot.unif(n_samples), [.5, .5],
- 'kl_loss', 2e-3,
+ 'kl_loss', 1e-3,
max_iter=100, tol=1e-3)
np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))