diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2023-06-09 20:26:52 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-06-09 20:26:52 +0200 |
commit | 6c1e1f3e064165d37e22acc866c6fff56e3ab6ad (patch) | |
tree | 354aa9c554a9e7490c93bd2ac579675f1b933329 | |
parent | 5faa4fbdb1a64351a42d31dd6f54f0402c29c405 (diff) |
[MRG] Update tests and documentation (#484)
* remove old macos and windows tets update requirements
* speedup ssw and continuaous ot exmaples
* speedup regpath and variane
* speedup conv 2d example + continuous stick
* speedup regpath
-rw-r--r-- | .github/workflows/build_tests.yml | 16 | ||||
-rw-r--r-- | RELEASES.md | 1 | ||||
-rw-r--r-- | docs/Makefile | 2 | ||||
-rw-r--r-- | examples/backends/plot_sliced_wass_grad_flow_pytorch.py | 8 | ||||
-rw-r--r-- | examples/backends/plot_ssw_unif_torch.py | 12 | ||||
-rw-r--r-- | examples/backends/plot_stoch_continuous_ot_pytorch.py | 6 | ||||
-rw-r--r-- | examples/barycenters/plot_convolutional_barycenter.py | 8 | ||||
-rw-r--r-- | examples/plot_OT_1D_smooth.py | 2 | ||||
-rw-r--r-- | examples/sliced-wasserstein/plot_variance.py | 8 | ||||
-rw-r--r-- | examples/sliced-wasserstein/plot_variance_ssw.py | 8 | ||||
-rw-r--r-- | examples/unbalanced-partial/plot_regpath.py | 14 |
11 files changed, 44 insertions, 41 deletions
diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index a2d26f1..a5e876b 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -83,7 +83,7 @@ jobs: pip install -e . - name: Run tests run: | - python -m pytest --durations=20 -v test/ ot/ --ignore ot/gpu/ --color=yes + python -m pytest --durations=20 -v test/ ot/ --color=yes macos: @@ -92,7 +92,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.7", "3.8", "3.9", "3.10"] + python-version: ["3.10"] steps: - uses: actions/checkout@v1 @@ -107,10 +107,10 @@ jobs: run: | python -m pip install --upgrade pip pip install -r requirements.txt - pip install pytest "pytest-cov<2.6" + pip install pytest - name: Run tests run: | - python -m pytest --durations=20 -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes + python -m pytest --durations=20 -v test/ ot/ --color=yes windows: @@ -119,7 +119,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.7", "3.8", "3.9", "3.10"] + python-version: ["3.10"] steps: - uses: actions/checkout@v1 @@ -151,8 +151,8 @@ jobs: - name: Install dependencies run: | python -m pip install -r .github/requirements_test_windows.txt - python -m pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html - python -m pip install pytest "pytest-cov<2.6" + python -m pip3 install torch torchvision torchaudio + python -m pip install pytest - name: Run tests run: | - python -m pytest --durations=20 -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes + python -m pytest --durations=20 -v test/ ot/ --color=yes diff --git a/RELEASES.md b/RELEASES.md index 5c50423..cd0bcde 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -7,6 +7,7 @@ - Added the sparsity-constrained OT solver to `ot.smooth` and added ` projection_sparse_simplex` to `ot.utils` (PR #459) - Add tests on GPU for master branch and approved PR (PR #473) - Add `median` method to all inherited classes of `backend.Backend` (PR #472) +- Update tests for macOS and Windows, speedup documentation (PR #484) #### Closed issues diff --git a/docs/Makefile b/docs/Makefile index 9892785..5aff9cd 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -50,6 +50,8 @@ help: .PHONY: clean clean: rm -rf $(BUILDDIR)/* + rm -rf source/gen_modules/* + rm -rf source/auto_examples/* .PHONY: html html: diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py index 07a4926..7cbfd98 100644 --- a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py +++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py @@ -40,8 +40,8 @@ import torch import ot import matplotlib.animation as animation -I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2] -I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::4, ::4, 2] +I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::5, ::5, 2] +I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::5, ::5, 2] sz = I2.shape[0] XX, YY = np.meshgrid(np.arange(sz), np.arange(sz)) @@ -67,7 +67,7 @@ x2_torch = torch.tensor(x2).to(device=device) lr = 1e3 -nb_iter_max = 100 +nb_iter_max = 50 x_all = np.zeros((nb_iter_max, x1.shape[0], 2)) @@ -129,7 +129,7 @@ xbinit = np.random.randn(500, 2) * 10 + 16 xbary_torch = torch.tensor(xbinit).to(device=device).requires_grad_(True) lr = 1e3 -nb_iter_max = 100 +nb_iter_max = 50 x_all = np.zeros((nb_iter_max, xbary_torch.shape[0], 2)) diff --git a/examples/backends/plot_ssw_unif_torch.py b/examples/backends/plot_ssw_unif_torch.py index afe3fa6..7459cf6 100644 --- a/examples/backends/plot_ssw_unif_torch.py +++ b/examples/backends/plot_ssw_unif_torch.py @@ -35,7 +35,7 @@ import ot torch.manual_seed(1) -N = 1000 +N = 500 x0 = torch.rand(N, 3) x0 = F.normalize(x0, dim=-1) @@ -72,8 +72,8 @@ ax.legend() x = x0.clone() x.requires_grad_(True) -n_iter = 500 -lr = 100 +n_iter = 100 +lr = 150 losses = [] xvisu = torch.zeros(n_iter, N, 3) @@ -82,7 +82,7 @@ for i in range(n_iter): sw = ot.sliced_wasserstein_sphere_unif(x, n_projections=500) grad_x = torch.autograd.grad(sw, x)[0] - x = x - lr * grad_x + x = x - lr * grad_x / np.sqrt(i / 10 + 1) x = F.normalize(x, p=2, dim=1) losses.append(sw.item()) @@ -102,7 +102,7 @@ pl.xlabel("Iterations") # Plot trajectories of generated samples along iterations # ------------------------------------------------------- -ivisu = [0, 25, 50, 75, 100, 150, 200, 350, 499] +ivisu = [0, 10, 20, 30, 40, 50, 60, 70, 80] fig = pl.figure(3, (10, 10)) for i in range(9): @@ -149,5 +149,5 @@ ax.set_ylim((-1.5, 1.5)) ax.set_title('Iter. {}'.format(ivisu[i])) -ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=100, repeat_delay=2000) +ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=200, repeat_delay=2000) # %% diff --git a/examples/backends/plot_stoch_continuous_ot_pytorch.py b/examples/backends/plot_stoch_continuous_ot_pytorch.py index 714a5d3..e642986 100644 --- a/examples/backends/plot_stoch_continuous_ot_pytorch.py +++ b/examples/backends/plot_stoch_continuous_ot_pytorch.py @@ -27,8 +27,8 @@ import ot.plot torch.manual_seed(42) np.random.seed(42) -n_source_samples = 10000 -n_target_samples = 10000 +n_source_samples = 1000 +n_target_samples = 1000 theta = 2 * np.pi / 20 noise_level = 0.1 @@ -89,7 +89,7 @@ reg = 1 optimizer = torch.optim.Adam(list(u.parameters()) + list(v.parameters()), lr=.005) # number of iteration -n_iter = 1000 +n_iter = 500 n_batch = 500 diff --git a/examples/barycenters/plot_convolutional_barycenter.py b/examples/barycenters/plot_convolutional_barycenter.py index 3721f31..143b3a6 100644 --- a/examples/barycenters/plot_convolutional_barycenter.py +++ b/examples/barycenters/plot_convolutional_barycenter.py @@ -29,10 +29,10 @@ import ot this_file = os.path.realpath('__file__') data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') -f1 = 1 - plt.imread(os.path.join(data_path, 'redcross.png'))[:, :, 2] -f2 = 1 - plt.imread(os.path.join(data_path, 'tooth.png'))[:, :, 2] -f3 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2] -f4 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2] +f1 = 1 - plt.imread(os.path.join(data_path, 'redcross.png'))[::2, ::2, 2] +f2 = 1 - plt.imread(os.path.join(data_path, 'tooth.png'))[::2, ::2, 2] +f3 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[::2, ::2, 2] +f4 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[::2, ::2, 2] f1 = f1 / np.sum(f1) f2 = f2 / np.sum(f2) diff --git a/examples/plot_OT_1D_smooth.py b/examples/plot_OT_1D_smooth.py index 626938c..4f233fe 100644 --- a/examples/plot_OT_1D_smooth.py +++ b/examples/plot_OT_1D_smooth.py @@ -14,7 +14,7 @@ sparsity-constrained OT, together with their visualizations. # # License: MIT License -# sphinx_gallery_thumbnail_number = 6 +# sphinx_gallery_thumbnail_number = 5 import numpy as np import matplotlib.pylab as pl diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py index 2293247..77df2f5 100644 --- a/examples/sliced-wasserstein/plot_variance.py +++ b/examples/sliced-wasserstein/plot_variance.py @@ -29,7 +29,7 @@ import ot # %% parameters and data generation -n = 500 # nb samples +n = 200 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -58,9 +58,9 @@ pl.title('Source and target distributions') # Sliced Wasserstein distance for different seeds and number of projections # ------------------------------------------------------------------------- -n_seed = 50 -n_projections_arr = np.logspace(0, 3, 25, dtype=int) -res = np.empty((n_seed, 25)) +n_seed = 20 +n_projections_arr = np.logspace(0, 3, 10, dtype=int) +res = np.empty((n_seed, 10)) # %% Compute statistics for seed in range(n_seed): diff --git a/examples/sliced-wasserstein/plot_variance_ssw.py b/examples/sliced-wasserstein/plot_variance_ssw.py index f5fc35f..246b2a8 100644 --- a/examples/sliced-wasserstein/plot_variance_ssw.py +++ b/examples/sliced-wasserstein/plot_variance_ssw.py @@ -28,7 +28,7 @@ import ot # %% parameters and data generation -n = 500 # nb samples +n = 200 # nb samples xs = np.random.randn(n, 3) xt = np.random.randn(n, 3) @@ -81,9 +81,9 @@ pl.title("Source and Target distribution") # Spherical Sliced Wasserstein for different seeds and number of projections # -------------------------------------------------------------------------- -n_seed = 50 -n_projections_arr = np.logspace(0, 3, 25, dtype=int) -res = np.empty((n_seed, 25)) +n_seed = 20 +n_projections_arr = np.logspace(0, 3, 10, dtype=int) +res = np.empty((n_seed, 10)) # %% Compute statistics for seed in range(n_seed): diff --git a/examples/unbalanced-partial/plot_regpath.py b/examples/unbalanced-partial/plot_regpath.py index d1f2042..ffedc6e 100644 --- a/examples/unbalanced-partial/plot_regpath.py +++ b/examples/unbalanced-partial/plot_regpath.py @@ -27,7 +27,7 @@ import matplotlib.animation as animation #%% parameters and data generation -n = 50 # nb samples +n = 20 # nb samples mu_s = np.array([-1, -1]) cov_s = np.array([[1, 0], [0, 1]]) @@ -63,7 +63,7 @@ pl.show() # ----------------------------------------------------------- #%% -final_gamma = 1e-8 +final_gamma = 1e-6 t, t_list, g_list = ot.regpath.regularization_path(a, b, M, reg=final_gamma, semi_relaxed=False) t2, t_list2, g_list2 = ot.regpath.regularization_path(a, b, M, reg=final_gamma, @@ -111,7 +111,7 @@ pl.show() # Animation of the regpath for UOT l2 # ----------------------------------- -nv = 100 +nv = 50 g_list_v = np.logspace(-.5, -2.5, nv) pl.figure(3) @@ -144,7 +144,7 @@ def _update_plot(iv): i = 0 _update_plot(i) -ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000) +ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=100, repeat_delay=2000) ############################################################################## @@ -183,8 +183,8 @@ pl.show() # Animation of the regpath for semi-relaxed UOT l2 # ------------------------------------------------ -nv = 100 -g_list_v = np.logspace(2.5, -2, nv) +nv = 50 +g_list_v = np.logspace(2, -2, nv) pl.figure(5) @@ -216,4 +216,4 @@ def _update_plot(iv): i = 0 _update_plot(i) -ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000) +ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=100, repeat_delay=2000) |