author    Rémi Flamary <remi.flamary@gmail.com>  2021-06-02 12:59:35 +0200
committer GitHub <noreply@github.com>            2021-06-02 12:59:35 +0200
commit    d693ac25988dd557cb1ee7fc96f3a656f7d4301c (patch)
tree      9dc6d91b816b4b379684fbd12bf0e853f22374c0
parent    184f8f4f7ac78f1dd7f653496d2753211a4e3426 (diff)
[WIP] Add Wasserstein GAN and debug memory leak (#254)
* add example and debug memory leak
* print stuff
* speedup gallery
* Apply suggestions from code review
* test cells
* proper header for GAN example
* cleanup sections
* last changes?

Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
-rw-r--r--  examples/backends/plot_wass2_gan_torch.py                      195
-rw-r--r--  examples/domain-adaptation/plot_otda_color_images.py             2
-rw-r--r--  examples/domain-adaptation/plot_otda_mapping_colors_images.py    2
-rw-r--r--  ot/backend.py                                                    34
4 files changed, 220 insertions, 13 deletions
diff --git a/examples/backends/plot_wass2_gan_torch.py b/examples/backends/plot_wass2_gan_torch.py
new file mode 100644
index 0000000..8f50022
--- /dev/null
+++ b/examples/backends/plot_wass2_gan_torch.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+r"""
+========================================
+Wasserstein 2 Minibatch GAN with PyTorch
+========================================
+
+In this example we train a Wasserstein GAN using the Wasserstein-2 distance
+computed on minibatches as the distribution fitting term.
+
+We want to train a generator :math:`G_\theta` that generates realistic
+data from random noise drawn from a Gaussian distribution :math:`\mu_n`, so
+that the generated data is indistinguishable from samples of the true data
+distribution :math:`\mu_d`. To this end, the Wasserstein GAN [Arjovsky2017] aims at optimizing
+the parameters :math:`\theta` of the generator with the following
+optimization problem:
+
+.. math::
+    \min_{\theta} W(\mu_d,G_\theta\#\mu_n)
+
+
+In practice we do not have access to the full distribution :math:`\mu_d`,
+only to samples from it, and computing the exact Wasserstein distance is
+intractable for large datasets. [Arjovsky2017] proposed to approximate the
+dual potential of Wasserstein 1 with a neural network, recovering an
+optimization problem similar to that of a GAN. In this example we instead
+optimize the expectation of the Wasserstein distance over minibatches at
+each iteration, as proposed in [Genevay2018]. The asymptotic and gradient
+properties of this minibatch Wasserstein loss have been studied in [Fatras2019].
+
+[Arjovsky2017] Arjovsky, M., Chintala, S., & Bottou, L. (2017, July).
+Wasserstein generative adversarial networks. In International conference
+on machine learning (pp. 214-223). PMLR.
+
+[Genevay2018] Genevay, Aude, Gabriel Peyré, and Marco Cuturi. "Learning generative models
+with Sinkhorn divergences." International Conference on Artificial Intelligence
+and Statistics. PMLR, 2018.
+
+[Fatras2019] Fatras, K., Zine, Y., Flamary, R., Gribonval, R., & Courty, N.
+(2020, June). Learning with minibatch Wasserstein: asymptotic and gradient
+properties. In the 23rd International Conference on Artificial Intelligence
+and Statistics (Vol. 108).
+
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 3
+
+import numpy as np
+import matplotlib.pyplot as pl
+import torch
+from torch import nn
+import ot
+
+
+# %%
+# Data generation
+# ---------------
+
+torch.manual_seed(1)
+sigma = 0.1
+n_dims = 2
+n_features = 2
+
+
+def get_data(n_samples):
+    c = torch.rand(size=(n_samples, 1))
+    angle = c * 2 * np.pi
+    x = torch.cat((torch.cos(angle), torch.sin(angle)), 1)
+    x += torch.randn(n_samples, 2) * sigma
+    return x
+
+
+# %%
+# Plot data
+# ---------
+
+# plot the distributions
+x = get_data(500)
+pl.figure(1)
+pl.scatter(x[:, 0], x[:, 1], label='Data samples from $\mu_d$', alpha=0.5)
+pl.title('Data distribution')
+pl.legend()
+
+
+# %%
+# Generator Model
+# ---------------
+
+# define the MLP model
+class Generator(torch.nn.Module):
+    def __init__(self):
+        super(Generator, self).__init__()
+        self.fc1 = nn.Linear(n_features, 200)
+        self.fc2 = nn.Linear(200, 500)
+        self.fc3 = nn.Linear(500, n_dims)
+        self.relu = torch.nn.ReLU()  # instead of Heaviside step fn
+
+    def forward(self, x):
+        output = self.fc1(x)
+        output = self.relu(output)  # instead of Heaviside step fn
+        output = self.fc2(output)
+        output = self.relu(output)
+        output = self.fc3(output)
+        return output
+
+# %%
+# Training the model
+# ------------------
+
+
+G = Generator()
+optimizer = torch.optim.RMSprop(G.parameters(), lr=0.001)
+
+# number of iterations and size of the minibatches
+n_iter = 500
+size_batch = 500
+
+# generate static noise samples to visualize their trajectory along training
+n_visu = 100
+xnvisu = torch.randn(n_visu, n_features)
+xvisu = torch.zeros(n_iter, n_visu, n_dims)
+
+ab = torch.ones(size_batch) / size_batch
+losses = []
+
+
+for i in range(n_iter):
+
+    # generate noise samples
+    xn = torch.randn(size_batch, n_features)
+
+    # generate data samples
+    xd = get_data(size_batch)
+
+    # store the samples generated from the fixed noise along iterations
+    xvisu[i, :, :] = G(xnvisu).detach()
+
+    # generate samples and compute the pairwise cost matrix
+    xg = G(xn)
+    M = ot.dist(xg, xd)
+
+    loss = ot.emd2(ab, ab, M)
+    losses.append(float(loss.detach()))
+
+    if i % 10 == 0:
+        print("Iter: {:3d}, loss={}".format(i, losses[-1]))
+
+    loss.backward()
+    optimizer.step()
+
+    del M  # free the cost matrix to release memory at each iteration
+
+pl.figure(2)
+pl.semilogy(losses)
+pl.grid()
+pl.title('Wasserstein distance')
+pl.xlabel("Iterations")
+
+
+# %%
+# Plot trajectories of generated samples along iterations
+# -------------------------------------------------------
+
+
+pl.figure(3, (10, 10))
+
+ivisu = [0, 10, 50, 100, 150, 200, 300, 400, 499]
+
+for i in range(9):
+    pl.subplot(3, 3, i + 1)
+    pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1)
+    pl.scatter(xvisu[ivisu[i], :, 0], xvisu[ivisu[i], :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5)
+    pl.xticks(())
+    pl.yticks(())
+    pl.title('Iter. {}'.format(ivisu[i]))
+    if i == 0:
+        pl.legend()
+
+# %%
+# Generate and visualize data
+# ---------------------------
+
+size_batch = 500
+xd = get_data(size_batch)
+xn = torch.randn(size_batch, 2)
+x = G(xn).detach().numpy()
+
+pl.figure(4)
+pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.5)
+pl.scatter(x[:, 0], x[:, 1], label='Data samples from $G\#\mu_n$', alpha=0.5)
+pl.title('Source and target distributions')
+pl.legend()
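Note: the training loop in this example relies on gradients flowing through ot.emd2 with torch tensors, which is what the ot/backend.py changes below enable. As a reference, here is a minimal, self-contained sketch of the same minibatch Wasserstein-2 fitting step, assuming the Generator class and get_data function defined in the example above; the explicit optimizer.zero_grad() call is an addition here (so gradients do not accumulate across iterations) and is not part of the committed example.

    import torch
    import ot

    G = Generator()  # generator model defined in the example above
    optimizer = torch.optim.RMSprop(G.parameters(), lr=0.001)
    size_batch = 500
    ab = torch.ones(size_batch) / size_batch  # uniform weights on the minibatch

    for it in range(500):
        xn = torch.randn(size_batch, 2)   # noise minibatch
        xd = get_data(size_batch)         # data minibatch (from the example above)

        xg = G(xn)                        # generated minibatch
        M = ot.dist(xg, xd)               # pairwise squared Euclidean cost matrix
        loss = ot.emd2(ab, ab, M)         # exact OT cost on the minibatch

        optimizer.zero_grad()             # added: reset accumulated gradients
        loss.backward()                   # gradients flow through ot.emd2 (torch backend)
        optimizer.step()
        del M                             # release the cost matrix, as in the example

        if it % 100 == 0:
            print("iter {:4d}  minibatch loss: {:.4f}".format(it, loss.item()))

With uniform weights and the default squared Euclidean cost of ot.dist, ot.emd2 returns the squared Wasserstein-2 distance between the two empirical minibatch distributions, which is the loss plotted during training above.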
diff --git a/examples/domain-adaptation/plot_otda_color_images.py b/examples/domain-adaptation/plot_otda_color_images.py
index d70f1fc..6218b13 100644
--- a/examples/domain-adaptation/plot_otda_color_images.py
+++ b/examples/domain-adaptation/plot_otda_color_images.py
@@ -53,7 +53,7 @@ X1 = im2mat(I1)
X2 = im2mat(I2)
# training samples
-nb = 1000
+nb = 500
idx1 = r.randint(X1.shape[0], size=(nb,))
idx2 = r.randint(X2.shape[0], size=(nb,))
diff --git a/examples/domain-adaptation/plot_otda_mapping_colors_images.py b/examples/domain-adaptation/plot_otda_mapping_colors_images.py
index aa41d22..72010a6 100644
--- a/examples/domain-adaptation/plot_otda_mapping_colors_images.py
+++ b/examples/domain-adaptation/plot_otda_mapping_colors_images.py
@@ -56,7 +56,7 @@ X1 = im2mat(I1)
X2 = im2mat(I2)
# training samples
-nb = 1000
+nb = 500
idx1 = r.randint(X1.shape[0], size=(nb,))
idx2 = r.randint(X2.shape[0], size=(nb,))
diff --git a/ot/backend.py b/ot/backend.py
index d68f5cf..8f46900 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -389,6 +389,26 @@ class TorchBackend(Backend):
__name__ = 'torch'
__type__ = torch_type
+    def __init__(self):
+
+        from torch.autograd import Function
+
+        # define a Function that takes a value tensor and precomputed grads
+        # and returns the value tensor with the proper gradients attached
+        class ValFunction(Function):
+
+            @staticmethod
+            def forward(ctx, val, grads, *inputs):
+                ctx.grads = grads
+                return val
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                # return the stored gradients for the inputs (None for val and grads)
+                return (None, None) + ctx.grads
+
+        self.ValFunction = ValFunction
+
    def to_numpy(self, a):
        return a.cpu().detach().numpy()
@@ -399,20 +419,12 @@ class TorchBackend(Backend):
return torch.as_tensor(a, dtype=type_as.dtype, device=type_as.device)
def set_gradients(self, val, inputs, grads):
-        from torch.autograd import Function
-        # define a function that takes inputs and return val
-        class ValFunction(Function):
-            @staticmethod
-            def forward(ctx, *inputs):
-                return val
+        Func = self.ValFunction()
-            @staticmethod
-            def backward(ctx, grad_output):
-                # the gradients are grad
-                return grads
+        res = Func.apply(val, grads, *inputs)
-        return ValFunction.apply(*inputs)
+        return res
def zeros(self, shape, type_as=None):
if type_as is None:
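To make the refactored set_gradients pattern above easier to follow, here is a small standalone sketch (illustration only, not part of the commit) of how a torch.autograd.Function can return a value computed outside of autograd while attaching precomputed gradients to chosen inputs. Building the Function class once in __init__ and passing val and grads explicitly through forward, rather than defining a fresh class that captures them in a closure on every call, is presumably what helps with the memory leak mentioned in the commit title.

    import torch
    from torch.autograd import Function


    class ValFunction(Function):
        # forward stores the precomputed gradients and simply returns the value
        @staticmethod
        def forward(ctx, val, grads, *inputs):
            ctx.grads = grads
            return val

        # backward returns None for val and grads, then the stored gradients for inputs;
        # grad_output is ignored, as in the backend, assuming val is the final scalar loss
        @staticmethod
        def backward(ctx, grad_output):
            return (None, None) + ctx.grads


    x = torch.tensor([1.0, 2.0], requires_grad=True)
    val = torch.tensor(3.0)               # a value computed outside of autograd
    grads = (torch.tensor([0.5, 0.5]),)   # its precomputed gradient w.r.t. x
    out = ValFunction.apply(val, grads, x)
    out.backward()
    print(x.grad)                         # tensor([0.5000, 0.5000])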