summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2021-06-01 10:10:54 +0200
committerGitHub <noreply@github.com>2021-06-01 10:10:54 +0200
commit184f8f4f7ac78f1dd7f653496d2753211a4e3426 (patch)
tree483a7274c91030fd644de49b03a5fad04af9deba /examples
parent1f16614954e2522fbdb1598c5b1f5c3630c68472 (diff)
[MRG] POT numpy/torch/jax backends (#249)
* add numpy and torch backends * stat sets on functions * proper import * install recent torch on windows * install recent torch on windows * now testing all functions in backedn * add jax backedn * clenaup windowds * proper convert for jax backedn * pep8 * try again windows tests * test jax conversion * try proper widows tests * emd fuction ses backedn * better test partial OT * proper tests to_numpy and teplate Backend * pep8 * pep8 x2 * feaking sinkhorn works with torch * sinkhorn2 compatible * working ot.emd2 * important detach * it should work * jax autodiff emd * pep8 * no tast same for jax * new independat tests per backedn * freaking pep8 * add tests for gradients * deprecate ot.gpu * worging dist function * working dist * dist done in backedn * not in * remove indexing * change accuacy for jax * first pull backend * projection simplex * projection simplex * projection simplex * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 * add backedn discusion to quickstart guide * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 + better doc * proper links * corect doctest * big debug documentation * doctest again * doctest again bis * doctest again ter (last one or i kill myself) * backend test + doc proj simplex * correction test_utils * correction test_utils * correction cumsum * correction flip * correction flip v2 * more debug * more debug * more debug + pep8 * pep8 * argh * proj_simplex * backedn works for sort * proj simplex * jax sucks * update doc * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/lp/__init__.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * begin comment alex * comment alex part 2 * optimize test gromov * proj_simplex on vectors * add awesome gradient decsnt example on the weights * pep98 of course * proof read example by alex * pep8 again * encoding oos in translation * correct legend Co-authored-by: Nicolas Courty <ncourty@irisa.fr> Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'examples')
-rw-r--r--examples/README.txt2
-rw-r--r--examples/backends/README.txt4
-rw-r--r--examples/backends/plot_unmix_optim_torch.py161
3 files changed, 166 insertions, 1 deletions
diff --git a/examples/README.txt b/examples/README.txt
index 69a9f84..b48487f 100644
--- a/examples/README.txt
+++ b/examples/README.txt
@@ -1,7 +1,7 @@
Examples gallery
================
-This is a gallery of all the POT example files.
+This is a gallery of all the POT example files.
OT and regularized OT
diff --git a/examples/backends/README.txt b/examples/backends/README.txt
new file mode 100644
index 0000000..3ee0e27
--- /dev/null
+++ b/examples/backends/README.txt
@@ -0,0 +1,4 @@
+
+
+POT backend examples
+-------------------- \ No newline at end of file
diff --git a/examples/backends/plot_unmix_optim_torch.py b/examples/backends/plot_unmix_optim_torch.py
new file mode 100644
index 0000000..9ae66e9
--- /dev/null
+++ b/examples/backends/plot_unmix_optim_torch.py
@@ -0,0 +1,161 @@
+# -*- coding: utf-8 -*-
+r"""
+=================================
+Wasserstein unmixing with PyTorch
+=================================
+
+In this example we estimate mixing parameters from distributions that minimize
+the Wasserstein distance. In other words we suppose that a target
+distribution :math:`\mu^t` can be expressed as a weighted sum of source
+distributions :math:`\mu^s_k` with the following model:
+
+.. math::
+ \mu^t = \sum_{k=1}^K w_k\mu^s_k
+
+where :math:`\mathbf{w}` is a vector of size :math:`K` and belongs in the
+distribution simplex :math:`\Delta_K`.
+
+In order to estimate this weight vector we propose to optimize the Wasserstein
+distance between the model and the observed :math:`\mu^t` with respect to
+the vector. This leads to the following optimization problem:
+
+.. math::
+ \min_{\mathbf{w}\in\Delta_K} \quad W \left(\mu^t,\sum_{k=1}^K w_k\mu^s_k\right)
+
+This minimization is done in this example with a simple projected gradient
+descent in PyTorch. We use the automatic backend of POT that allows us to
+compute the Wasserstein distance with :any:`ot.emd2` with
+differentiable losses.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import torch
+
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% Data
+
+nt = 100
+nt1 = 10 #
+
+ns1 = 50
+ns = 2 * ns1
+
+rng = np.random.RandomState(2)
+
+xt = rng.randn(nt, 2) * 0.2
+xt[:nt1, 0] += 1
+xt[nt1:, 1] += 1
+
+
+xs1 = rng.randn(ns1, 2) * 0.2
+xs1[:, 0] += 1
+xs2 = rng.randn(ns1, 2) * 0.2
+xs2[:, 1] += 1
+
+xs = np.concatenate((xs1, xs2))
+
+# Sample reweighting matrix H
+H = np.zeros((ns, 2))
+H[:ns1, 0] = 1 / ns1
+H[ns1:, 1] = 1 / ns1
+# each columns sums to 1 and has weights only for samples form the
+# corresponding source distribution
+
+M = ot.dist(xs, xt)
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% plot the distributions
+
+pl.figure(1)
+pl.scatter(xt[:, 0], xt[:, 1], label='Target $\mu^t$', alpha=0.5)
+pl.scatter(xs1[:, 0], xs1[:, 1], label='Source $\mu^s_1$', alpha=0.5)
+pl.scatter(xs2[:, 0], xs2[:, 1], label='Source $\mu^s_2$', alpha=0.5)
+pl.title('Sources and Target distributions')
+pl.legend()
+
+
+##############################################################################
+# Optimization of the model wrt the Wasserstein distance
+# ------------------------------------------------------
+
+
+#%% Weights optimization with gradient descent
+
+# convert numpy arrays to torch tensors
+H2 = torch.tensor(H)
+M2 = torch.tensor(M)
+
+# weights for the source distributions
+w = torch.tensor(ot.unif(2), requires_grad=True)
+
+# uniform weights for target
+b = torch.tensor(ot.unif(nt))
+
+lr = 2e-3 # learning rate
+niter = 500 # number of iterations
+losses = [] # loss along the iterations
+
+# loss for the minimal Wasserstein estimator
+
+
+def get_loss(w):
+ a = torch.mv(H2, w) # distribution reweighting
+ return ot.emd2(a, b, M2) # squared Wasserstein 2
+
+
+for i in range(niter):
+
+ loss = get_loss(w)
+ losses.append(float(loss))
+
+ loss.backward()
+
+ with torch.no_grad():
+ w -= lr * w.grad # gradient step
+ w[:] = ot.utils.proj_simplex(w) # projection on the simplex
+
+ w.grad.zero_()
+
+
+##############################################################################
+# Estimated weights and convergence of the objective
+# ---------------------------------------------------
+
+we = w.detach().numpy()
+print('Estimated mixture:', we)
+
+pl.figure(2)
+pl.semilogy(losses)
+pl.grid()
+pl.title('Wasserstein distance')
+pl.xlabel("Iterations")
+
+##############################################################################
+# Ploting the reweighted source distribution
+# ------------------------------------------
+
+pl.figure(3)
+
+# compute source weights
+ws = H.dot(we)
+
+pl.scatter(xt[:, 0], xt[:, 1], label='Target $\mu^t$', alpha=0.5)
+pl.scatter(xs[:, 0], xs[:, 1], color='C3', s=ws * 20 * ns, label='Weighted sources $\sum_{k} w_k\mu^s_k$', alpha=0.5)
+pl.title('Target and reweighted source distributions')
+pl.legend()