From d693ac25988dd557cb1ee7fc96f3a656f7d4301c Mon Sep 17 00:00:00 2001
From: Rémi Flamary
Date: Wed, 2 Jun 2021 12:59:35 +0200
Subject: [WIP] Add Wasserstein GAN and debug memory leak (#254)

* add example and debug memory leak

* print stuff

* speedup gallery

* Apply suggestions from code review

Co-authored-by: Alexandre Gramfort

* test cells

* proper header gan example

* cleanup sections

* last changes ?

Co-authored-by: Alexandre Gramfort
---
 examples/backends/plot_wass2_gan_torch.py        | 195 +++++++++++++++++++++
 .../domain-adaptation/plot_otda_color_images.py  |   2 +-
 .../plot_otda_mapping_colors_images.py           |   2 +-
 3 files changed, 197 insertions(+), 2 deletions(-)
 create mode 100644 examples/backends/plot_wass2_gan_torch.py

diff --git a/examples/backends/plot_wass2_gan_torch.py b/examples/backends/plot_wass2_gan_torch.py
new file mode 100644
index 0000000..8f50022
--- /dev/null
+++ b/examples/backends/plot_wass2_gan_torch.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+r"""
+========================================
+Wasserstein 2 Minibatch GAN with PyTorch
+========================================
+
+In this example we train a Wasserstein GAN using Wasserstein 2 on minibatches
+as a distribution fitting term.
+
+We want to train a generator :math:`G_\theta` that generates realistic
+data from random noise drawn from a Gaussian distribution :math:`\mu_n`, so
+that the generated data is indistinguishable from true samples of the data
+distribution :math:`\mu_d`. To this end the Wasserstein GAN [Arjovsky2017]
+aims at optimizing the parameters :math:`\theta` of the generator with the
+following optimization problem:
+
+.. math::
+    \min_{\theta} W(\mu_d, G_\theta\#\mu_n)
+
+In practice we do not have access to the full distribution :math:`\mu_d`, only
+to samples, and we cannot compute the Wasserstein distance for large datasets.
+[Arjovsky2017] proposed to approximate the dual potential of Wasserstein 1
+with a neural network, recovering an optimization problem similar to the GAN
+one. In this example we instead optimize the expectation of the Wasserstein
+distance over minibatches at each iteration, as proposed in [Genevay2018].
+Optimizing minibatches of the Wasserstein distance has been studied in
+[Fatras2019].
+
+[Arjovsky2017] Arjovsky, M., Chintala, S., & Bottou, L. (2017, July).
+Wasserstein generative adversarial networks. In International Conference
+on Machine Learning (pp. 214-223). PMLR.
+
+[Genevay2018] Genevay, Aude, Gabriel Peyré, and Marco Cuturi. "Learning
+generative models with Sinkhorn divergences." International Conference on
+Artificial Intelligence and Statistics. PMLR, 2018.
+
+[Fatras2019] Fatras, K., Zine, Y., Flamary, R., Gribonval, R., & Courty, N.
+(2020, June). Learning with minibatch Wasserstein: asymptotic and gradient
+properties. In the 23rd International Conference on Artificial Intelligence
+and Statistics (Vol. 108).
+ +""" + +# Author: Remi Flamary +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 3 + +import numpy as np +import matplotlib.pyplot as pl +import torch +from torch import nn +import ot + + +# %% +# Data generation +# --------------- + +torch.manual_seed(1) +sigma = 0.1 +n_dims = 2 +n_features = 2 + + +def get_data(n_samples): + c = torch.rand(size=(n_samples, 1)) + angle = c * 2 * np.pi + x = torch.cat((torch.cos(angle), torch.sin(angle)), 1) + x += torch.randn(n_samples, 2) * sigma + return x + + +# %% +# Plot data +# --------- + +# plot the distributions +x = get_data(500) +pl.figure(1) +pl.scatter(x[:, 0], x[:, 1], label='Data samples from $\mu_d$', alpha=0.5) +pl.title('Data distribution') +pl.legend() + + +# %% +# Generator Model +# --------------- + +# define the MLP model +class Generator(torch.nn.Module): + def __init__(self): + super(Generator, self).__init__() + self.fc1 = nn.Linear(n_features, 200) + self.fc2 = nn.Linear(200, 500) + self.fc3 = nn.Linear(500, n_dims) + self.relu = torch.nn.ReLU() # instead of Heaviside step fn + + def forward(self, x): + output = self.fc1(x) + output = self.relu(output) # instead of Heaviside step fn + output = self.fc2(output) + output = self.relu(output) + output = self.fc3(output) + return output + +# %% +# Training the model +# ------------------ + + +G = Generator() +optimizer = torch.optim.RMSprop(G.parameters(), lr=0.001) + +# number of iteration and size of the batches +n_iter = 500 +size_batch = 500 + +# generate statis samples to see their trajectory along training +n_visu = 100 +xnvisu = torch.randn(n_visu, n_features) +xvisu = torch.zeros(n_iter, n_visu, n_dims) + +ab = torch.ones(size_batch) / size_batch +losses = [] + + +for i in range(n_iter): + + # generate noise samples + xn = torch.randn(size_batch, n_features) + + # generate data samples + xd = get_data(size_batch) + + # generate sample along iterations + xvisu[i, :, :] = G(xnvisu).detach() + + # generate smaples and compte distance matrix + xg = G(xn) + M = ot.dist(xg, xd) + + loss = ot.emd2(ab, ab, M) + losses.append(float(loss.detach())) + + if i % 10 == 0: + print("Iter: {:3d}, loss={}".format(i, losses[-1])) + + loss.backward() + optimizer.step() + + del M + +pl.figure(2) +pl.semilogy(losses) +pl.grid() +pl.title('Wasserstein distance') +pl.xlabel("Iterations") + + +# %% +# Plot trajectories of generated samples along iterations +# ------------------------------------------------------- + + +pl.figure(3, (10, 10)) + +ivisu = [0, 10, 50, 100, 150, 200, 300, 400, 499] + +for i in range(9): + pl.subplot(3, 3, i + 1) + pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1) + pl.scatter(xvisu[ivisu[i], :, 0], xvisu[ivisu[i], :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5) + pl.xticks(()) + pl.yticks(()) + pl.title('Iter. 
+    if i == 0:
+        pl.legend()
+
+
+# %%
+# Generate and visualize data
+# ---------------------------
+
+size_batch = 500
+xd = get_data(size_batch)
+xn = torch.randn(size_batch, 2)
+x = G(xn).detach().numpy()
+
+pl.figure(4)
+pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.5)
+pl.scatter(x[:, 0], x[:, 1], label='Data samples from $G\#\mu_n$', alpha=0.5)
+pl.title('Source and Target distributions')
+pl.legend()
diff --git a/examples/domain-adaptation/plot_otda_color_images.py b/examples/domain-adaptation/plot_otda_color_images.py
index d70f1fc..6218b13 100644
--- a/examples/domain-adaptation/plot_otda_color_images.py
+++ b/examples/domain-adaptation/plot_otda_color_images.py
@@ -53,7 +53,7 @@ X1 = im2mat(I1)
 X2 = im2mat(I2)
 
 # training samples
-nb = 1000
+nb = 500
 idx1 = r.randint(X1.shape[0], size=(nb,))
 idx2 = r.randint(X2.shape[0], size=(nb,))
 
diff --git a/examples/domain-adaptation/plot_otda_mapping_colors_images.py b/examples/domain-adaptation/plot_otda_mapping_colors_images.py
index aa41d22..72010a6 100644
--- a/examples/domain-adaptation/plot_otda_mapping_colors_images.py
+++ b/examples/domain-adaptation/plot_otda_mapping_colors_images.py
@@ -56,7 +56,7 @@ X1 = im2mat(I1)
 X2 = im2mat(I2)
 
 # training samples
-nb = 1000
+nb = 500
 idx1 = r.randint(X1.shape[0], size=(nb,))
 idx2 = r.randint(X2.shape[0], size=(nb,))
 
-- cgit v1.2.3
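The core of the training loop added by this patch is the exact minibatch Wasserstein-2 loss from POT. Below is a minimal standalone sketch of that fitting term, not part of the patch itself: it assumes a POT version with PyTorch backend support for ot.dist and ot.emd2 (the same requirement as the example above) and uses toy tensors in place of the generator output.

import torch
import ot

size_batch = 64
# toy stand-ins: xg plays the role of the generated minibatch G(xn), xd of the data minibatch
xg = torch.randn(size_batch, 2, requires_grad=True)
xd = torch.randn(size_batch, 2) + 1.0

ab = torch.ones(size_batch) / size_batch   # uniform weights on both minibatches
M = ot.dist(xg, xd)                        # pairwise squared Euclidean costs (default metric)
loss = ot.emd2(ab, ab, M)                  # exact squared W2 between the two minibatches

loss.backward()                            # gradients flow back through M to xg
print(float(loss), xg.grad.shape)

Because ot.emd2 is differentiable with respect to the cost matrix under the torch backend, replacing xg with G(xn) and calling optimizer.step() recovers the training loop of the example.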