diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2023-06-09 20:26:52 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-06-09 20:26:52 +0200 |
commit | 6c1e1f3e064165d37e22acc866c6fff56e3ab6ad (patch) | |
tree | 354aa9c554a9e7490c93bd2ac579675f1b933329 /examples/backends/plot_ssw_unif_torch.py | |
parent | 5faa4fbdb1a64351a42d31dd6f54f0402c29c405 (diff) |
[MRG] Update tests and documentation (#484)
* remove old macos and windows tets update requirements
* speedup ssw and continuaous ot exmaples
* speedup regpath and variane
* speedup conv 2d example + continuous stick
* speedup regpath
Diffstat (limited to 'examples/backends/plot_ssw_unif_torch.py')
-rw-r--r-- | examples/backends/plot_ssw_unif_torch.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/examples/backends/plot_ssw_unif_torch.py b/examples/backends/plot_ssw_unif_torch.py index afe3fa6..7459cf6 100644 --- a/examples/backends/plot_ssw_unif_torch.py +++ b/examples/backends/plot_ssw_unif_torch.py @@ -35,7 +35,7 @@ import ot torch.manual_seed(1) -N = 1000 +N = 500 x0 = torch.rand(N, 3) x0 = F.normalize(x0, dim=-1) @@ -72,8 +72,8 @@ ax.legend() x = x0.clone() x.requires_grad_(True) -n_iter = 500 -lr = 100 +n_iter = 100 +lr = 150 losses = [] xvisu = torch.zeros(n_iter, N, 3) @@ -82,7 +82,7 @@ for i in range(n_iter): sw = ot.sliced_wasserstein_sphere_unif(x, n_projections=500) grad_x = torch.autograd.grad(sw, x)[0] - x = x - lr * grad_x + x = x - lr * grad_x / np.sqrt(i / 10 + 1) x = F.normalize(x, p=2, dim=1) losses.append(sw.item()) @@ -102,7 +102,7 @@ pl.xlabel("Iterations") # Plot trajectories of generated samples along iterations # ------------------------------------------------------- -ivisu = [0, 25, 50, 75, 100, 150, 200, 350, 499] +ivisu = [0, 10, 20, 30, 40, 50, 60, 70, 80] fig = pl.figure(3, (10, 10)) for i in range(9): @@ -149,5 +149,5 @@ ax.set_ylim((-1.5, 1.5)) ax.set_title('Iter. {}'.format(ivisu[i])) -ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=100, repeat_delay=2000) +ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=200, repeat_delay=2000) # %% |