diff options
author | eloitanguy <69361683+eloitanguy@users.noreply.github.com> | 2022-05-11 08:57:54 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-05-11 08:57:54 +0200 |
commit | d6bf10d8502b1c69f58f009b16634a110053eca1 (patch) | |
tree | 8d74efb46fa79063f7c2285f1d99c41b5b2b9ac3 /test | |
parent | c1ccfc45350f8db3fa78d91b84eb4286bcf36e69 (diff) |
[WIP] Graphical tweaks for GWB + fixed seed method for the partial gromov test (#376)
* GWB first solver version
* tests + example for gwb (untested) + free_bar doc fix
* improved doc, fixed minor bugs, better example visu
* minor doc + visu fixes
* plot GWB pep8 fix
* fixed partial gromov test reproductibility
* added an animation for the GWB visu
* added PR num
* minor doc fixes + better gwb logo
* GWB graphical tweaks + better seed method for partial gromov test
* fixed PR number
* refixed seed issue
* seed fix fix fix
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test')
-rwxr-xr-x | test/test_partial.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/test/test_partial.py b/test/test_partial.py index e07377b..33fc259 100755 --- a/test/test_partial.py +++ b/test/test_partial.py @@ -137,7 +137,7 @@ def test_partial_wasserstein(): def test_partial_gromov_wasserstein(): - np.random.seed(42) + rng = np.random.RandomState(seed=42) n_samples = 20 # nb samples n_noise = 10 # nb of samples (noise) @@ -150,11 +150,11 @@ def test_partial_gromov_wasserstein(): mu_t = np.array([0, 0, 0]) cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s) - xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) + 1) * 4)), axis=0) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, rng) + xs = np.concatenate((xs, ((rng.rand(n_noise, 2) + 1) * 4)), axis=0) P = sp.linalg.sqrtm(cov_t) - xt = np.random.randn(n_samples, 3).dot(P) + mu_t - xt = np.concatenate((xt, ((np.random.rand(n_noise, 3) + 1) * 10)), axis=0) + xt = rng.randn(n_samples, 3).dot(P) + mu_t + xt = np.concatenate((xt, ((rng.rand(n_noise, 3) + 1) * 10)), axis=0) xt2 = xs[::-1].copy() C1 = ot.dist(xs, xs) |