diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/test_gromov.py | 26 |
1 files changed, 9 insertions, 17 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py index f218b74..70fa83f 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -11,14 +11,12 @@ import ot def test_gromov():
- np.random.seed(42)
-
n_samples = 50 # nb samples
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
- xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
xt = xs[::-1].copy()
@@ -58,14 +56,12 @@ def test_gromov(): def test_entropic_gromov():
- np.random.seed(42)
-
n_samples = 50 # nb samples
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
- xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
xt = xs[::-1].copy()
@@ -102,13 +98,11 @@ def test_entropic_gromov(): def test_gromov_barycenter():
- np.random.seed(42)
-
ns = 50
nt = 60
- Xs, ys = ot.datasets.make_data_classif('3gauss', ns)
- Xt, yt = ot.datasets.make_data_classif('3gauss2', nt)
+ Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
+ Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
C1 = ot.dist(Xs)
C2 = ot.dist(Xt)
@@ -131,12 +125,11 @@ def test_gromov_barycenter(): def test_gromov_entropic_barycenter():
- np.random.seed(42)
ns = 50
nt = 60
- Xs, ys = ot.datasets.make_data_classif('3gauss', ns)
- Xt, yt = ot.datasets.make_data_classif('3gauss2', nt)
+ Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
+ Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
C1 = ot.dist(Xs)
C2 = ot.dist(Xt)
@@ -159,14 +152,13 @@ def test_gromov_entropic_barycenter(): def test_fgw():
- np.random.seed(42)
n_samples = 50 # nb samples
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
- xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
xt = xs[::-1].copy()
@@ -217,8 +209,8 @@ def test_fgw_barycenter(): ns = 50
nt = 60
- Xs, ys = ot.datasets.make_data_classif('3gauss', ns)
- Xt, yt = ot.datasets.make_data_classif('3gauss2', nt)
+ Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
+ Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
ys = np.random.randn(Xs.shape[0], 2)
yt = np.random.randn(Xt.shape[0], 2)
|