summaryrefslogtreecommitdiff
path: root/examples/backends/plot_wass2_gan_torch.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/backends/plot_wass2_gan_torch.py')
-rw-r--r--examples/backends/plot_wass2_gan_torch.py40
1 files changed, 36 insertions, 4 deletions
diff --git a/examples/backends/plot_wass2_gan_torch.py b/examples/backends/plot_wass2_gan_torch.py
index 8f50022..ca5b3c9 100644
--- a/examples/backends/plot_wass2_gan_torch.py
+++ b/examples/backends/plot_wass2_gan_torch.py
@@ -50,6 +50,7 @@ and Statistics (Vol. 108).
import numpy as np
import matplotlib.pyplot as pl
+import matplotlib.animation as animation
import torch
from torch import nn
import ot
@@ -112,10 +113,10 @@ class Generator(torch.nn.Module):
G = Generator()
-optimizer = torch.optim.RMSprop(G.parameters(), lr=0.001)
+optimizer = torch.optim.RMSprop(G.parameters(), lr=0.00019, eps=1e-5)
# number of iteration and size of the batches
-n_iter = 500
+n_iter = 200 # set to 200 for doc build but 1000 is better ;)
size_batch = 500
# generate statis samples to see their trajectory along training
@@ -167,7 +168,7 @@ pl.xlabel("Iterations")
pl.figure(3, (10, 10))
-ivisu = [0, 10, 50, 100, 150, 200, 300, 400, 499]
+ivisu = [0, 10, 25, 50, 75, 125, 15, 175, 199]
for i in range(9):
pl.subplot(3, 3, i + 1)
@@ -180,6 +181,37 @@ for i in range(9):
pl.legend()
# %%
+# Animate trajectories of generated samples along iteration
+# -------------------------------------------------------
+
+pl.figure(4, (8, 8))
+
+
+def _update_plot(i):
+ pl.clf()
+ pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1)
+ pl.scatter(xvisu[i, :, 0], xvisu[i, :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5)
+ pl.xticks(())
+ pl.yticks(())
+ pl.xlim((-1.5, 1.5))
+ pl.ylim((-1.5, 1.5))
+ pl.title('Iter. {}'.format(i))
+ return 1
+
+
+i = 0
+pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1)
+pl.scatter(xvisu[i, :, 0], xvisu[i, :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5)
+pl.xticks(())
+pl.yticks(())
+pl.xlim((-1.5, 1.5))
+pl.ylim((-1.5, 1.5))
+pl.title('Iter. {}'.format(ivisu[i]))
+
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter, interval=100, repeat_delay=2000)
+
+# %%
# Generate and visualize data
# ---------------------------
@@ -188,7 +220,7 @@ xd = get_data(size_batch)
xn = torch.randn(size_batch, 2)
x = G(xn).detach().numpy()
-pl.figure(4)
+pl.figure(5)
pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.5)
pl.scatter(x[:, 0], x[:, 1], label='Data samples from $G\#\mu_n$', alpha=0.5)
pl.title('Sources and Target distributions')