author     Rémi Flamary <remi.flamary@gmail.com>        2021-06-01 10:10:54 +0200
committer  GitHub <noreply@github.com>                   2021-06-01 10:10:54 +0200
commit     184f8f4f7ac78f1dd7f653496d2753211a4e3426 (patch)
tree       483a7274c91030fd644de49b03a5fad04af9deba /examples/backends
parent     1f16614954e2522fbdb1598c5b1f5c3630c68472 (diff)
[MRG] POT numpy/torch/jax backends (#249)
* add numpy and torch backends
* stat sets on functions
* proper import
* install recent torch on windows
* install recent torch on windows
* now testing all functions in backend
* add jax backend
* cleanup windows
* proper convert for jax backend
* pep8
* try again windows tests
* test jax conversion
* try proper windows tests
* emd function uses backend
* better test partial OT
* proper tests to_numpy and template Backend
* pep8
* pep8 x2
* freaking sinkhorn works with torch
* sinkhorn2 compatible
* working ot.emd2
* important detach
* it should work
* jax autodiff emd
* pep8
* no test same for jax
* new independent tests per backend
* freaking pep8
* add tests for gradients
* deprecate ot.gpu
* working dist function
* working dist
* dist done in backend
* not in
* remove indexing
* change accuracy for jax
* first pull backend
* projection simplex
* projection simplex
* projection simplex
* projection simplex no ci
* projection simplex no ci
* projection simplex no ci
* pep8
* add backend discussion to quickstart guide
* projection simplex no ci
* projection simplex no ci
* projection simplex no ci
* pep8 + better doc
* proper links
* correct doctest
* big debug documentation
* doctest again
* doctest again bis
* doctest again ter (last one or i kill myself)
* backend test + doc proj simplex
* correction test_utils
* correction test_utils
* correction cumsum
* correction flip
* correction flip v2
* more debug
* more debug
* more debug + pep8
* pep8
* argh
* proj_simplex
* backend works for sort
* proj simplex
* jax sucks
* update doc
* Update test/test_utils.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/quickstart.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/quickstart.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/quickstart.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/readme.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update test/test_utils.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update ot/utils.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/readme.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update ot/lp/__init__.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* begin comment alex
* comment alex part 2
* optimize test gromov
* proj_simplex on vectors
* add awesome gradient descent example on the weights
* pep8 of course
* proof read example by alex
* pep8 again
* encoding oops in translation
* correct legend
Co-authored-by: Nicolas Courty <ncourty@irisa.fr>
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
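Beyond the example added in this diff, the backend dispatch described above means that `ot.emd2` accepts numpy arrays or torch tensors alike and returns a loss of the matching type, differentiable in the torch case. Below is a minimal sketch of that behavior, assuming the post-commit API (`ot.unif`, `ot.dist` and `ot.emd2` are real POT functions; the toy data and variable names are illustrative only):

import numpy as np
import torch
import ot

# hypothetical toy data: two small point clouds in 2D
rng = np.random.RandomState(0)
xs = rng.randn(30, 2)
xt = rng.randn(40, 2) + 1

a, b = ot.unif(30), ot.unif(40)  # uniform sample weights
M = ot.dist(xs, xt)              # squared Euclidean cost matrix

# numpy inputs -> plain float out
loss_np = ot.emd2(a, b, M)

# torch inputs -> torch scalar with autograd support
M_t = torch.tensor(M, requires_grad=True)
loss_t = ot.emd2(torch.tensor(a), torch.tensor(b), M_t)
loss_t.backward()                # fills M_t.grad with the loss gradient
print(loss_np, float(loss_t))    # same value from either backend

The same call with jax arrays should dispatch to the jax backend (the log mentions "jax autodiff emd"), which is the point of the commit: one solver, several array libraries.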
Diffstat (limited to 'examples/backends')
-rw-r--r--   examples/backends/README.txt                    4
-rw-r--r--   examples/backends/plot_unmix_optim_torch.py   161
2 files changed, 165 insertions, 0 deletions
diff --git a/examples/backends/README.txt b/examples/backends/README.txt
new file mode 100644
index 0000000..3ee0e27
--- /dev/null
+++ b/examples/backends/README.txt
@@ -0,0 +1,4 @@
+
+
+POT backend examples
+--------------------
\ No newline at end of file
diff --git a/examples/backends/plot_unmix_optim_torch.py b/examples/backends/plot_unmix_optim_torch.py
new file mode 100644
index 0000000..9ae66e9
--- /dev/null
+++ b/examples/backends/plot_unmix_optim_torch.py
@@ -0,0 +1,161 @@
+# -*- coding: utf-8 -*-
+r"""
+=================================
+Wasserstein unmixing with PyTorch
+=================================
+
+In this example we estimate mixing parameters from distributions that minimize
+the Wasserstein distance. In other words we suppose that a target
+distribution :math:`\mu^t` can be expressed as a weighted sum of source
+distributions :math:`\mu^s_k` with the following model:
+
+.. math::
+
+    \mu^t = \sum_{k=1}^K w_k\mu^s_k
+
+where :math:`\mathbf{w}` is a vector of size :math:`K` that belongs to the
+probability simplex :math:`\Delta_K`.
+
+To estimate this weight vector, we optimize the Wasserstein distance between
+the model and the observed :math:`\mu^t` with respect to :math:`\mathbf{w}`.
+This leads to the following optimization problem:
+
+.. math::
+
+    \min_{\mathbf{w}\in\Delta_K} \quad W \left(\mu^t,\sum_{k=1}^K w_k\mu^s_k\right)
+
+In this example the minimization is done with a simple projected gradient
+descent in PyTorch. We use the automatic backend of POT, which makes the
+Wasserstein distance computed by :any:`ot.emd2` differentiable with respect
+to its inputs.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import torch
+
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% Data
+
+nt = 100
+nt1 = 10  # number of target samples in the first cluster
+
+ns1 = 50
+ns = 2 * ns1
+
+rng = np.random.RandomState(2)
+
+xt = rng.randn(nt, 2) * 0.2
+xt[:nt1, 0] += 1
+xt[nt1:, 1] += 1
+
+xs1 = rng.randn(ns1, 2) * 0.2
+xs1[:, 0] += 1
+xs2 = rng.randn(ns1, 2) * 0.2
+xs2[:, 1] += 1
+
+xs = np.concatenate((xs1, xs2))
+
+# Sample reweighting matrix H
+H = np.zeros((ns, 2))
+H[:ns1, 0] = 1 / ns1
+H[ns1:, 1] = 1 / ns1
+# each column sums to 1 and has weights only for samples from the
+# corresponding source distribution
+
+M = ot.dist(xs, xt)
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% plot the distributions
+
+pl.figure(1)
+pl.scatter(xt[:, 0], xt[:, 1], label='Target $\mu^t$', alpha=0.5)
+pl.scatter(xs1[:, 0], xs1[:, 1], label='Source $\mu^s_1$', alpha=0.5)
+pl.scatter(xs2[:, 0], xs2[:, 1], label='Source $\mu^s_2$', alpha=0.5)
+pl.title('Sources and Target distributions')
+pl.legend()
+
+
+##############################################################################
+# Optimization of the model wrt the Wasserstein distance
+# ------------------------------------------------------
+
+
+#%% Weights optimization with gradient descent
+
+# convert numpy arrays to torch tensors
+H2 = torch.tensor(H)
+M2 = torch.tensor(M)
+
+# weights for the source distributions
+w = torch.tensor(ot.unif(2), requires_grad=True)
+
+# uniform weights for target
+b = torch.tensor(ot.unif(nt))
+
+lr = 2e-3  # learning rate
+niter = 500  # number of iterations
+losses = []  # loss along the iterations
+
+# loss for the minimal Wasserstein estimator
+
+
+def get_loss(w):
+    a = torch.mv(H2, w)  # distribution reweighting
+    return ot.emd2(a, b, M2)  # squared Wasserstein 2
+
+
+for i in range(niter):
+
+    loss = get_loss(w)
+    losses.append(float(loss))
+
+    loss.backward()
+
+    with torch.no_grad():
+        w -= lr * w.grad  # gradient step
+        w[:] = ot.utils.proj_simplex(w)  # projection on the simplex
+
+    w.grad.zero_()
+
+
+##############################################################################
+# Estimated weights and convergence of the objective
+# ---------------------------------------------------
+
+we = w.detach().numpy()
+print('Estimated mixture:', we)
+
+pl.figure(2)
+pl.semilogy(losses)
+pl.grid()
+pl.title('Wasserstein distance')
+pl.xlabel("Iterations")
+
+##############################################################################
+# Plotting the reweighted source distribution
+# -------------------------------------------
+
+pl.figure(3)
+
+# compute source weights
+ws = H.dot(we)
+
+pl.scatter(xt[:, 0], xt[:, 1], label='Target $\mu^t$', alpha=0.5)
+pl.scatter(xs[:, 0], xs[:, 1], color='C3', s=ws * 20 * ns, label='Weighted sources $\sum_{k} w_k\mu^s_k$', alpha=0.5)
+pl.title('Target and reweighted source distributions')
+pl.legend()
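The optimization loop in the example above is a projected gradient descent: each iteration takes a gradient step on :math:`\mathbf{w}` and then maps it back onto the simplex, :math:`\mathbf{w} \leftarrow \mathrm{proj}_{\Delta_K}(\mathbf{w} - \eta \nabla_{\mathbf{w}} W)`. The projection is `ot.utils.proj_simplex`, which the commit log shows being built from backend sort/cumsum/flip primitives. For reference, here is a standalone numpy sketch of the classical sort-based Euclidean projection onto the simplex (an illustration of the algorithm, not POT's actual implementation; the function name is hypothetical):

import numpy as np

def proj_simplex_sketch(v, z=1.0):
    # Euclidean projection of v onto {x : x >= 0, sum(x) = z}
    u = np.sort(v)[::-1]                  # sort coordinates in decreasing order
    css = np.cumsum(u)                    # cumulative sums of the sorted values
    ind = np.arange(1, len(v) + 1)
    # largest index whose thresholded coordinate stays positive
    rho = np.nonzero(u * ind > css - z)[0][-1]
    theta = (css[rho] - z) / (rho + 1.0)  # shared threshold
    return np.maximum(v - theta, 0.0)

print(proj_simplex_sketch(np.array([0.4, 0.8])))  # -> [0.3 0.7], sums to 1

For v = [0.4, 0.8] the threshold is theta = 0.1, giving [0.3, 0.7]: nonnegative and summing to 1, as the weight vector in the example must be after every update.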