summaryrefslogtreecommitdiff
path: root/examples/backends/plot_ssw_unif_torch.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/backends/plot_ssw_unif_torch.py')
-rw-r--r--examples/backends/plot_ssw_unif_torch.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/examples/backends/plot_ssw_unif_torch.py b/examples/backends/plot_ssw_unif_torch.py
index afe3fa6..7459cf6 100644
--- a/examples/backends/plot_ssw_unif_torch.py
+++ b/examples/backends/plot_ssw_unif_torch.py
@@ -35,7 +35,7 @@ import ot
torch.manual_seed(1)
-N = 1000
+N = 500
x0 = torch.rand(N, 3)
x0 = F.normalize(x0, dim=-1)
@@ -72,8 +72,8 @@ ax.legend()
x = x0.clone()
x.requires_grad_(True)
-n_iter = 500
-lr = 100
+n_iter = 100
+lr = 150
losses = []
xvisu = torch.zeros(n_iter, N, 3)
@@ -82,7 +82,7 @@ for i in range(n_iter):
sw = ot.sliced_wasserstein_sphere_unif(x, n_projections=500)
grad_x = torch.autograd.grad(sw, x)[0]
- x = x - lr * grad_x
+ x = x - lr * grad_x / np.sqrt(i / 10 + 1)
x = F.normalize(x, p=2, dim=1)
losses.append(sw.item())
@@ -102,7 +102,7 @@ pl.xlabel("Iterations")
# Plot trajectories of generated samples along iterations
# -------------------------------------------------------
-ivisu = [0, 25, 50, 75, 100, 150, 200, 350, 499]
+ivisu = [0, 10, 20, 30, 40, 50, 60, 70, 80]
fig = pl.figure(3, (10, 10))
for i in range(9):
@@ -149,5 +149,5 @@ ax.set_ylim((-1.5, 1.5))
ax.set_title('Iter. {}'.format(ivisu[i]))
-ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=100, repeat_delay=2000)
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=200, repeat_delay=2000)
# %%