summaryrefslogtreecommitdiff
path: root/examples/backends/plot_ssw_unif_torch.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2023-06-09 20:26:52 +0200
committerGitHub <noreply@github.com>2023-06-09 20:26:52 +0200
commit6c1e1f3e064165d37e22acc866c6fff56e3ab6ad (patch)
tree354aa9c554a9e7490c93bd2ac579675f1b933329 /examples/backends/plot_ssw_unif_torch.py
parent5faa4fbdb1a64351a42d31dd6f54f0402c29c405 (diff)
[MRG] Update tests and documentation (#484)
* remove old macos and windows tets update requirements * speedup ssw and continuaous ot exmaples * speedup regpath and variane * speedup conv 2d example + continuous stick * speedup regpath
Diffstat (limited to 'examples/backends/plot_ssw_unif_torch.py')
-rw-r--r--examples/backends/plot_ssw_unif_torch.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/examples/backends/plot_ssw_unif_torch.py b/examples/backends/plot_ssw_unif_torch.py
index afe3fa6..7459cf6 100644
--- a/examples/backends/plot_ssw_unif_torch.py
+++ b/examples/backends/plot_ssw_unif_torch.py
@@ -35,7 +35,7 @@ import ot
torch.manual_seed(1)
-N = 1000
+N = 500
x0 = torch.rand(N, 3)
x0 = F.normalize(x0, dim=-1)
@@ -72,8 +72,8 @@ ax.legend()
x = x0.clone()
x.requires_grad_(True)
-n_iter = 500
-lr = 100
+n_iter = 100
+lr = 150
losses = []
xvisu = torch.zeros(n_iter, N, 3)
@@ -82,7 +82,7 @@ for i in range(n_iter):
sw = ot.sliced_wasserstein_sphere_unif(x, n_projections=500)
grad_x = torch.autograd.grad(sw, x)[0]
- x = x - lr * grad_x
+ x = x - lr * grad_x / np.sqrt(i / 10 + 1)
x = F.normalize(x, p=2, dim=1)
losses.append(sw.item())
@@ -102,7 +102,7 @@ pl.xlabel("Iterations")
# Plot trajectories of generated samples along iterations
# -------------------------------------------------------
-ivisu = [0, 25, 50, 75, 100, 150, 200, 350, 499]
+ivisu = [0, 10, 20, 30, 40, 50, 60, 70, 80]
fig = pl.figure(3, (10, 10))
for i in range(9):
@@ -149,5 +149,5 @@ ax.set_ylim((-1.5, 1.5))
ax.set_title('Iter. {}'.format(ivisu[i]))
-ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=100, repeat_delay=2000)
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=200, repeat_delay=2000)
# %%