summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--docs/source/conf.py4
-rw-r--r--examples/backends/plot_wass2_gan_torch.py40
2 files changed, 39 insertions, 5 deletions
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 3a11798..9b5a719 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -337,7 +337,8 @@ texinfo_documents = [
intersphinx_mapping = {'python': ('https://docs.python.org/3', None),
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
'scipy': ('http://docs.scipy.org/doc/scipy/reference/', None),
- 'matplotlib': ('http://matplotlib.org/', None)}
+ 'matplotlib': ('http://matplotlib.org/', None),
+ 'torch': ('https://pytorch.org/docs/stable/', None)}
sphinx_gallery_conf = {
'examples_dirs': ['../../examples', '../../examples/da'],
@@ -345,6 +346,7 @@ sphinx_gallery_conf = {
'backreferences_dir': 'gen_modules/backreferences',
'inspect_global_variables' : True,
'doc_module' : ('ot','numpy','scipy','pylab'),
+ 'matplotlib_animations': True,
'reference_url': {
'ot': None}
}
diff --git a/examples/backends/plot_wass2_gan_torch.py b/examples/backends/plot_wass2_gan_torch.py
index 8f50022..ca5b3c9 100644
--- a/examples/backends/plot_wass2_gan_torch.py
+++ b/examples/backends/plot_wass2_gan_torch.py
@@ -50,6 +50,7 @@ and Statistics (Vol. 108).
import numpy as np
import matplotlib.pyplot as pl
+import matplotlib.animation as animation
import torch
from torch import nn
import ot
@@ -112,10 +113,10 @@ class Generator(torch.nn.Module):
G = Generator()
-optimizer = torch.optim.RMSprop(G.parameters(), lr=0.001)
+optimizer = torch.optim.RMSprop(G.parameters(), lr=0.00019, eps=1e-5)
# number of iteration and size of the batches
-n_iter = 500
+n_iter = 200 # set to 200 for doc build but 1000 is better ;)
size_batch = 500
# generate statis samples to see their trajectory along training
@@ -167,7 +168,7 @@ pl.xlabel("Iterations")
pl.figure(3, (10, 10))
-ivisu = [0, 10, 50, 100, 150, 200, 300, 400, 499]
+ivisu = [0, 10, 25, 50, 75, 125, 15, 175, 199]
for i in range(9):
pl.subplot(3, 3, i + 1)
@@ -180,6 +181,37 @@ for i in range(9):
pl.legend()
# %%
+# Animate trajectories of generated samples along iteration
+# -------------------------------------------------------
+
+pl.figure(4, (8, 8))
+
+
+def _update_plot(i):
+ pl.clf()
+ pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1)
+ pl.scatter(xvisu[i, :, 0], xvisu[i, :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5)
+ pl.xticks(())
+ pl.yticks(())
+ pl.xlim((-1.5, 1.5))
+ pl.ylim((-1.5, 1.5))
+ pl.title('Iter. {}'.format(i))
+ return 1
+
+
+i = 0
+pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1)
+pl.scatter(xvisu[i, :, 0], xvisu[i, :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5)
+pl.xticks(())
+pl.yticks(())
+pl.xlim((-1.5, 1.5))
+pl.ylim((-1.5, 1.5))
+pl.title('Iter. {}'.format(ivisu[i]))
+
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter, interval=100, repeat_delay=2000)
+
+# %%
# Generate and visualize data
# ---------------------------
@@ -188,7 +220,7 @@ xd = get_data(size_batch)
xn = torch.randn(size_batch, 2)
x = G(xn).detach().numpy()
-pl.figure(4)
+pl.figure(5)
pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.5)
pl.scatter(x[:, 0], x[:, 1], label='Data samples from $G\#\mu_n$', alpha=0.5)
pl.title('Sources and Target distributions')