diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2023-06-09 20:26:52 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-06-09 20:26:52 +0200 |
commit | 6c1e1f3e064165d37e22acc866c6fff56e3ab6ad (patch) | |
tree | 354aa9c554a9e7490c93bd2ac579675f1b933329 /examples/sliced-wasserstein | |
parent | 5faa4fbdb1a64351a42d31dd6f54f0402c29c405 (diff) |
[MRG] Update tests and documentation (#484)
* remove old macos and windows tets update requirements
* speedup ssw and continuaous ot exmaples
* speedup regpath and variane
* speedup conv 2d example + continuous stick
* speedup regpath
Diffstat (limited to 'examples/sliced-wasserstein')
-rw-r--r-- | examples/sliced-wasserstein/plot_variance.py | 8 | ||||
-rw-r--r-- | examples/sliced-wasserstein/plot_variance_ssw.py | 8 |
2 files changed, 8 insertions, 8 deletions
diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py index 2293247..77df2f5 100644 --- a/examples/sliced-wasserstein/plot_variance.py +++ b/examples/sliced-wasserstein/plot_variance.py @@ -29,7 +29,7 @@ import ot # %% parameters and data generation -n = 500 # nb samples +n = 200 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -58,9 +58,9 @@ pl.title('Source and target distributions') # Sliced Wasserstein distance for different seeds and number of projections # ------------------------------------------------------------------------- -n_seed = 50 -n_projections_arr = np.logspace(0, 3, 25, dtype=int) -res = np.empty((n_seed, 25)) +n_seed = 20 +n_projections_arr = np.logspace(0, 3, 10, dtype=int) +res = np.empty((n_seed, 10)) # %% Compute statistics for seed in range(n_seed): diff --git a/examples/sliced-wasserstein/plot_variance_ssw.py b/examples/sliced-wasserstein/plot_variance_ssw.py index f5fc35f..246b2a8 100644 --- a/examples/sliced-wasserstein/plot_variance_ssw.py +++ b/examples/sliced-wasserstein/plot_variance_ssw.py @@ -28,7 +28,7 @@ import ot # %% parameters and data generation -n = 500 # nb samples +n = 200 # nb samples xs = np.random.randn(n, 3) xt = np.random.randn(n, 3) @@ -81,9 +81,9 @@ pl.title("Source and Target distribution") # Spherical Sliced Wasserstein for different seeds and number of projections # -------------------------------------------------------------------------- -n_seed = 50 -n_projections_arr = np.logspace(0, 3, 25, dtype=int) -res = np.empty((n_seed, 25)) +n_seed = 20 +n_projections_arr = np.logspace(0, 3, 10, dtype=int) +res = np.empty((n_seed, 10)) # %% Compute statistics for seed in range(n_seed): |