From 0afd84d744a472903d427e3c7ae32e55fdd7b9a7 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Mon, 4 Apr 2022 10:23:04 +0200 Subject: [WIP] Add backend dual loss and plan computation for stochastic optimization or regularized OT (#360) * add losses and plan computations and exmaple for dual oiptimization * pep8 * add nice exmaple * update awesome example stochasti dual * add all tests * pep8 + speedup exmaple * add release info --- examples/backends/plot_dual_ot_pytorch.py | 168 ++++++++++++++++++++++++++++++ 1 file changed, 168 insertions(+) create mode 100644 examples/backends/plot_dual_ot_pytorch.py (limited to 'examples/backends/plot_dual_ot_pytorch.py') diff --git a/examples/backends/plot_dual_ot_pytorch.py b/examples/backends/plot_dual_ot_pytorch.py new file mode 100644 index 0000000..d3f7a66 --- /dev/null +++ b/examples/backends/plot_dual_ot_pytorch.py @@ -0,0 +1,168 @@ +# -*- coding: utf-8 -*- +r""" +====================================================================== +Dual OT solvers for entropic and quadratic regularized OT with Pytorch +====================================================================== + + +""" + +# Author: Remi Flamary +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 3 + +import numpy as np +import matplotlib.pyplot as pl +import torch +import ot +import ot.plot + +# %% +# Data generation +# --------------- + +torch.manual_seed(1) + +n_source_samples = 100 +n_target_samples = 100 +theta = 2 * np.pi / 20 +noise_level = 0.1 + +Xs, ys = ot.datasets.make_data_classif( + 'gaussrot', n_source_samples, nz=noise_level) +Xt, yt = ot.datasets.make_data_classif( + 'gaussrot', n_target_samples, theta=theta, nz=noise_level) + +# one of the target mode changes its variance (no linear mapping) +Xt[yt == 2] *= 3 +Xt = Xt + 4 + + +# %% +# Plot data +# --------- + +pl.figure(1, (10, 5)) +pl.clf() +pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples') +pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples') +pl.legend(loc=0) +pl.title('Source and target distributions') + +# %% +# Convert data to torch tensors +# ----------------------------- + +xs = torch.tensor(Xs) +xt = torch.tensor(Xt) + +# %% +# Estimating dual variables for entropic OT +# ----------------------------------------- + +u = torch.randn(n_source_samples, requires_grad=True) +v = torch.randn(n_source_samples, requires_grad=True) + +reg = 0.5 + +optimizer = torch.optim.Adam([u, v], lr=1) + +# number of iteration +n_iter = 200 + + +losses = [] + +for i in range(n_iter): + + # generate noise samples + + # minus because we maximize te dual loss + loss = -ot.stochastic.loss_dual_entropic(u, v, xs, xt, reg=reg) + losses.append(float(loss.detach())) + + if i % 10 == 0: + print("Iter: {:3d}, loss={}".format(i, losses[-1])) + + loss.backward() + optimizer.step() + optimizer.zero_grad() + + +pl.figure(2) +pl.plot(losses) +pl.grid() +pl.title('Dual objective (negative)') +pl.xlabel("Iterations") + +Ge = ot.stochastic.plan_dual_entropic(u, v, xs, xt, reg=reg) + +# %% +# Plot teh estimated entropic OT plan +# ----------------------------------- + +pl.figure(3, (10, 5)) +pl.clf() +ot.plot.plot2D_samples_mat(Xs, Xt, Ge.detach().numpy(), alpha=0.1) +pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples', zorder=2) +pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples', zorder=2) +pl.legend(loc=0) +pl.title('Source and target distributions') + + +# %% +# Estimating dual variables for quadratic OT +# ----------------------------------------- + +u = torch.randn(n_source_samples, requires_grad=True) +v = torch.randn(n_source_samples, requires_grad=True) + +reg = 0.01 + +optimizer = torch.optim.Adam([u, v], lr=1) + +# number of iteration +n_iter = 200 + + +losses = [] + + +for i in range(n_iter): + + # generate noise samples + + # minus because we maximize te dual loss + loss = -ot.stochastic.loss_dual_quadratic(u, v, xs, xt, reg=reg) + losses.append(float(loss.detach())) + + if i % 10 == 0: + print("Iter: {:3d}, loss={}".format(i, losses[-1])) + + loss.backward() + optimizer.step() + optimizer.zero_grad() + + +pl.figure(4) +pl.plot(losses) +pl.grid() +pl.title('Dual objective (negative)') +pl.xlabel("Iterations") + +Gq = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, reg=reg) + + +# %% +# Plot the estimated quadratic OT plan +# ----------------------------------- + +pl.figure(5, (10, 5)) +pl.clf() +ot.plot.plot2D_samples_mat(Xs, Xt, Gq.detach().numpy(), alpha=0.1) +pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples', zorder=2) +pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples', zorder=2) +pl.legend(loc=0) +pl.title('OT plan with quadratic regularization') -- cgit v1.2.3