From 982510eb5085a0edd7a00fb96a308854957d32bf Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Tue, 8 Jun 2021 22:32:18 +0200 Subject: [MRG] Update example GAN to avoid the 10 minute CircleCI limit (#258) * shortened example GAN * pep8 and typo * awesome animation * small eror pep8 * add animation to doc * better timing animation * tune step --- examples/backends/plot_wass2_gan_torch.py | 40 +++++++++++++++++++++++++++---- 1 file changed, 36 insertions(+), 4 deletions(-) (limited to 'examples') diff --git a/examples/backends/plot_wass2_gan_torch.py b/examples/backends/plot_wass2_gan_torch.py index 8f50022..ca5b3c9 100644 --- a/examples/backends/plot_wass2_gan_torch.py +++ b/examples/backends/plot_wass2_gan_torch.py @@ -50,6 +50,7 @@ and Statistics (Vol. 108). import numpy as np import matplotlib.pyplot as pl +import matplotlib.animation as animation import torch from torch import nn import ot @@ -112,10 +113,10 @@ class Generator(torch.nn.Module): G = Generator() -optimizer = torch.optim.RMSprop(G.parameters(), lr=0.001) +optimizer = torch.optim.RMSprop(G.parameters(), lr=0.00019, eps=1e-5) # number of iteration and size of the batches -n_iter = 500 +n_iter = 200 # set to 200 for doc build but 1000 is better ;) size_batch = 500 # generate statis samples to see their trajectory along training @@ -167,7 +168,7 @@ pl.xlabel("Iterations") pl.figure(3, (10, 10)) -ivisu = [0, 10, 50, 100, 150, 200, 300, 400, 499] +ivisu = [0, 10, 25, 50, 75, 125, 15, 175, 199] for i in range(9): pl.subplot(3, 3, i + 1) @@ -179,6 +180,37 @@ for i in range(9): if i == 0: pl.legend() +# %% +# Animate trajectories of generated samples along iteration +# ------------------------------------------------------- + +pl.figure(4, (8, 8)) + + +def _update_plot(i): + pl.clf() + pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1) + pl.scatter(xvisu[i, :, 0], xvisu[i, :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5) + pl.xticks(()) + pl.yticks(()) + pl.xlim((-1.5, 1.5)) + pl.ylim((-1.5, 1.5)) + pl.title('Iter. {}'.format(i)) + return 1 + + +i = 0 +pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1) +pl.scatter(xvisu[i, :, 0], xvisu[i, :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5) +pl.xticks(()) +pl.yticks(()) +pl.xlim((-1.5, 1.5)) +pl.ylim((-1.5, 1.5)) +pl.title('Iter. {}'.format(ivisu[i])) + + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter, interval=100, repeat_delay=2000) + # %% # Generate and visualize data # --------------------------- @@ -188,7 +220,7 @@ xd = get_data(size_batch) xn = torch.randn(size_batch, 2) x = G(xn).detach().numpy() -pl.figure(4) +pl.figure(5) pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.5) pl.scatter(x[:, 0], x[:, 1], label='Data samples from $G\#\mu_n$', alpha=0.5) pl.title('Sources and Target distributions') -- cgit v1.2.3