diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2021-06-08 22:32:18 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-06-08 22:32:18 +0200 |
commit | 982510eb5085a0edd7a00fb96a308854957d32bf (patch) | |
tree | 4f511d22b3b0997b44b2ed90a606c61abfff9f4e /examples/backends | |
parent | 221e04b87ca48926c961051fc2bdac8e72aa32ad (diff) |
[MRG] Update example GAN to avoid the 10 minute CircleCI limit (#258)
* shortened example GAN
* pep8 and typo
* awesome animation
* small eror pep8
* add animation to doc
* better timing animation
* tune step
Diffstat (limited to 'examples/backends')
-rw-r--r-- | examples/backends/plot_wass2_gan_torch.py | 40 |
1 files changed, 36 insertions, 4 deletions
diff --git a/examples/backends/plot_wass2_gan_torch.py b/examples/backends/plot_wass2_gan_torch.py index 8f50022..ca5b3c9 100644 --- a/examples/backends/plot_wass2_gan_torch.py +++ b/examples/backends/plot_wass2_gan_torch.py @@ -50,6 +50,7 @@ and Statistics (Vol. 108). import numpy as np import matplotlib.pyplot as pl +import matplotlib.animation as animation import torch from torch import nn import ot @@ -112,10 +113,10 @@ class Generator(torch.nn.Module): G = Generator() -optimizer = torch.optim.RMSprop(G.parameters(), lr=0.001) +optimizer = torch.optim.RMSprop(G.parameters(), lr=0.00019, eps=1e-5) # number of iteration and size of the batches -n_iter = 500 +n_iter = 200 # set to 200 for doc build but 1000 is better ;) size_batch = 500 # generate statis samples to see their trajectory along training @@ -167,7 +168,7 @@ pl.xlabel("Iterations") pl.figure(3, (10, 10)) -ivisu = [0, 10, 50, 100, 150, 200, 300, 400, 499] +ivisu = [0, 10, 25, 50, 75, 125, 15, 175, 199] for i in range(9): pl.subplot(3, 3, i + 1) @@ -180,6 +181,37 @@ for i in range(9): pl.legend() # %% +# Animate trajectories of generated samples along iteration +# ------------------------------------------------------- + +pl.figure(4, (8, 8)) + + +def _update_plot(i): + pl.clf() + pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1) + pl.scatter(xvisu[i, :, 0], xvisu[i, :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5) + pl.xticks(()) + pl.yticks(()) + pl.xlim((-1.5, 1.5)) + pl.ylim((-1.5, 1.5)) + pl.title('Iter. {}'.format(i)) + return 1 + + +i = 0 +pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1) +pl.scatter(xvisu[i, :, 0], xvisu[i, :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5) +pl.xticks(()) +pl.yticks(()) +pl.xlim((-1.5, 1.5)) +pl.ylim((-1.5, 1.5)) +pl.title('Iter. {}'.format(ivisu[i])) + + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter, interval=100, repeat_delay=2000) + +# %% # Generate and visualize data # --------------------------- @@ -188,7 +220,7 @@ xd = get_data(size_batch) xn = torch.randn(size_batch, 2) x = G(xn).detach().numpy() -pl.figure(4) +pl.figure(5) pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.5) pl.scatter(x[:, 0], x[:, 1], label='Data samples from $G\#\mu_n$', alpha=0.5) pl.title('Sources and Target distributions') |