diff options
-rw-r--r-- | docs/source/conf.py | 4 | ||||
-rw-r--r-- | examples/backends/plot_wass2_gan_torch.py | 40 |
2 files changed, 39 insertions, 5 deletions
diff --git a/docs/source/conf.py b/docs/source/conf.py index 3a11798..9b5a719 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -337,7 +337,8 @@ texinfo_documents = [ intersphinx_mapping = {'python': ('https://docs.python.org/3', None), 'numpy': ('http://docs.scipy.org/doc/numpy/', None), 'scipy': ('http://docs.scipy.org/doc/scipy/reference/', None), - 'matplotlib': ('http://matplotlib.org/', None)} + 'matplotlib': ('http://matplotlib.org/', None), + 'torch': ('https://pytorch.org/docs/stable/', None)} sphinx_gallery_conf = { 'examples_dirs': ['../../examples', '../../examples/da'], @@ -345,6 +346,7 @@ sphinx_gallery_conf = { 'backreferences_dir': 'gen_modules/backreferences', 'inspect_global_variables' : True, 'doc_module' : ('ot','numpy','scipy','pylab'), + 'matplotlib_animations': True, 'reference_url': { 'ot': None} } diff --git a/examples/backends/plot_wass2_gan_torch.py b/examples/backends/plot_wass2_gan_torch.py index 8f50022..ca5b3c9 100644 --- a/examples/backends/plot_wass2_gan_torch.py +++ b/examples/backends/plot_wass2_gan_torch.py @@ -50,6 +50,7 @@ and Statistics (Vol. 108). import numpy as np import matplotlib.pyplot as pl +import matplotlib.animation as animation import torch from torch import nn import ot @@ -112,10 +113,10 @@ class Generator(torch.nn.Module): G = Generator() -optimizer = torch.optim.RMSprop(G.parameters(), lr=0.001) +optimizer = torch.optim.RMSprop(G.parameters(), lr=0.00019, eps=1e-5) # number of iteration and size of the batches -n_iter = 500 +n_iter = 200 # set to 200 for doc build but 1000 is better ;) size_batch = 500 # generate statis samples to see their trajectory along training @@ -167,7 +168,7 @@ pl.xlabel("Iterations") pl.figure(3, (10, 10)) -ivisu = [0, 10, 50, 100, 150, 200, 300, 400, 499] +ivisu = [0, 10, 25, 50, 75, 125, 15, 175, 199] for i in range(9): pl.subplot(3, 3, i + 1) @@ -180,6 +181,37 @@ for i in range(9): pl.legend() # %% +# Animate trajectories of generated samples along iteration +# ------------------------------------------------------- + +pl.figure(4, (8, 8)) + + +def _update_plot(i): + pl.clf() + pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1) + pl.scatter(xvisu[i, :, 0], xvisu[i, :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5) + pl.xticks(()) + pl.yticks(()) + pl.xlim((-1.5, 1.5)) + pl.ylim((-1.5, 1.5)) + pl.title('Iter. {}'.format(i)) + return 1 + + +i = 0 +pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1) +pl.scatter(xvisu[i, :, 0], xvisu[i, :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5) +pl.xticks(()) +pl.yticks(()) +pl.xlim((-1.5, 1.5)) +pl.ylim((-1.5, 1.5)) +pl.title('Iter. {}'.format(ivisu[i])) + + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter, interval=100, repeat_delay=2000) + +# %% # Generate and visualize data # --------------------------- @@ -188,7 +220,7 @@ xd = get_data(size_batch) xn = torch.randn(size_batch, 2) x = G(xn).detach().numpy() -pl.figure(4) +pl.figure(5) pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.5) pl.scatter(x[:, 0], x[:, 1], label='Data samples from $G\#\mu_n$', alpha=0.5) pl.title('Sources and Target distributions') |