summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2021-06-08 22:32:18 +0200
committerGitHub <noreply@github.com>2021-06-08 22:32:18 +0200
commit982510eb5085a0edd7a00fb96a308854957d32bf (patch)
tree4f511d22b3b0997b44b2ed90a606c61abfff9f4e /examples
parent221e04b87ca48926c961051fc2bdac8e72aa32ad (diff)
[MRG] Update example GAN to avoid the 10 minute CircleCI limit (#258)
* shortened example GAN * pep8 and typo * awesome animation * small error pep8 * add animation to doc * better timing animation * tune step
Diffstat (limited to 'examples')
-rw-r--r--examples/backends/plot_wass2_gan_torch.py40
1 files changed, 36 insertions, 4 deletions
diff --git a/examples/backends/plot_wass2_gan_torch.py b/examples/backends/plot_wass2_gan_torch.py
index 8f50022..ca5b3c9 100644
--- a/examples/backends/plot_wass2_gan_torch.py
+++ b/examples/backends/plot_wass2_gan_torch.py
@@ -50,6 +50,7 @@ and Statistics (Vol. 108).
import numpy as np
import matplotlib.pyplot as pl
+import matplotlib.animation as animation
import torch
from torch import nn
import ot
@@ -112,10 +113,10 @@ class Generator(torch.nn.Module):
G = Generator()
-optimizer = torch.optim.RMSprop(G.parameters(), lr=0.001)
+optimizer = torch.optim.RMSprop(G.parameters(), lr=0.00019, eps=1e-5)
# number of iteration and size of the batches
-n_iter = 500
+n_iter = 200 # set to 200 for doc build but 1000 is better ;)
size_batch = 500
# generate static samples to see their trajectory along training
@@ -167,7 +168,7 @@ pl.xlabel("Iterations")
pl.figure(3, (10, 10))
-ivisu = [0, 10, 50, 100, 150, 200, 300, 400, 499]
+ivisu = [0, 10, 25, 50, 75, 125, 150, 175, 199]
for i in range(9):
pl.subplot(3, 3, i + 1)
@@ -180,6 +181,37 @@ for i in range(9):
pl.legend()
# %%
+# Animate trajectories of generated samples along iteration
+# ---------------------------------------------------------
+
+pl.figure(4, (8, 8))
+
+
+def _update_plot(i):
+ pl.clf()
+ pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1)
+ pl.scatter(xvisu[i, :, 0], xvisu[i, :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5)
+ pl.xticks(())
+ pl.yticks(())
+ pl.xlim((-1.5, 1.5))
+ pl.ylim((-1.5, 1.5))
+ pl.title('Iter. {}'.format(i))
+ return 1
+
+
+i = 0
+pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1)
+pl.scatter(xvisu[i, :, 0], xvisu[i, :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5)
+pl.xticks(())
+pl.yticks(())
+pl.xlim((-1.5, 1.5))
+pl.ylim((-1.5, 1.5))
+pl.title('Iter. {}'.format(ivisu[i]))
+
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter, interval=100, repeat_delay=2000)
+
+# %%
# Generate and visualize data
# ---------------------------
@@ -188,7 +220,7 @@ xd = get_data(size_batch)
xn = torch.randn(size_batch, 2)
x = G(xn).detach().numpy()
-pl.figure(4)
+pl.figure(5)
pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.5)
pl.scatter(x[:, 0], x[:, 1], label='Data samples from $G\#\mu_n$', alpha=0.5)
pl.title('Sources and Target distributions')