diff options
Diffstat (limited to 'examples/backends/plot_ssw_unif_torch.py')
-rw-r--r-- | examples/backends/plot_ssw_unif_torch.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/examples/backends/plot_ssw_unif_torch.py b/examples/backends/plot_ssw_unif_torch.py index afe3fa6..7459cf6 100644 --- a/examples/backends/plot_ssw_unif_torch.py +++ b/examples/backends/plot_ssw_unif_torch.py @@ -35,7 +35,7 @@ import ot torch.manual_seed(1) -N = 1000 +N = 500 x0 = torch.rand(N, 3) x0 = F.normalize(x0, dim=-1) @@ -72,8 +72,8 @@ ax.legend() x = x0.clone() x.requires_grad_(True) -n_iter = 500 -lr = 100 +n_iter = 100 +lr = 150 losses = [] xvisu = torch.zeros(n_iter, N, 3) @@ -82,7 +82,7 @@ for i in range(n_iter): sw = ot.sliced_wasserstein_sphere_unif(x, n_projections=500) grad_x = torch.autograd.grad(sw, x)[0] - x = x - lr * grad_x + x = x - lr * grad_x / np.sqrt(i / 10 + 1) x = F.normalize(x, p=2, dim=1) losses.append(sw.item()) @@ -102,7 +102,7 @@ pl.xlabel("Iterations") # Plot trajectories of generated samples along iterations # ------------------------------------------------------- -ivisu = [0, 25, 50, 75, 100, 150, 200, 350, 499] +ivisu = [0, 10, 20, 30, 40, 50, 60, 70, 80] fig = pl.figure(3, (10, 10)) for i in range(9): @@ -149,5 +149,5 @@ ax.set_ylim((-1.5, 1.5)) ax.set_title('Iter. {}'.format(ivisu[i])) -ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=100, repeat_delay=2000) +ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=200, repeat_delay=2000) # %% |