summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorLaetitia Chapel <laetitia.chapel@univ-ubs.fr>2022-04-11 15:38:18 +0200
committerGitHub <noreply@github.com>2022-04-11 15:38:18 +0200
commitac4cf442735ed4c0d5405ad861eddaa02afd4edd (patch)
tree6f0bf54ca7452621bc55548f2a2a2615b8975b54 /examples
parent0b223ff883fd73601984a92c31cb70d4aded16e8 (diff)
[MRG] MM algorithms for UOT (#362)
* bugfix * update refs partial OT * fixes small typos in plot_partial_wass_and_gromov * fix small bugs in partial.py * update README * pep8 bugfix * modif doctest * fix bugtests * update on test_partial and test on the numerical precision on ot/partial * resolve merge pb * Delete partial.py * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * update releases.md with new MM UOT algorithms Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'examples')
-rw-r--r--examples/unbalanced-partial/plot_unbalanced_OT.py116
1 files changed, 116 insertions, 0 deletions
diff --git a/examples/unbalanced-partial/plot_unbalanced_OT.py b/examples/unbalanced-partial/plot_unbalanced_OT.py
new file mode 100644
index 0000000..03487e7
--- /dev/null
+++ b/examples/unbalanced-partial/plot_unbalanced_OT.py
@@ -0,0 +1,116 @@
+# -*- coding: utf-8 -*-
+"""
+==============================================================
+2D examples of exact and entropic unbalanced optimal transport
+==============================================================
+This example is designed to show how to compute unbalanced and
+partial OT in POT.
+
+UOT aims at solving the following optimization problem:
+
+ .. math::
+ W = \min_{\gamma} <\gamma, \mathbf{M}>_F +
+ \mathrm{reg}\cdot\Omega(\gamma) +
+ \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) +
+ \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b})
+
+ s.t.
+ \gamma \geq 0
+
+where :math:`\mathrm{div}` is a divergence.
+When using the entropic UOT, :math:`\mathrm{reg}>0` and :math:`\mathrm{div}`
+should be the Kullback-Leibler divergence.
+When solving exact UOT, :math:`\mathrm{reg}=0` and :math:`\mathrm{div}`
+can be either the Kullback-Leibler or the quadratic divergence.
+Using :math:`\ell_1` norm gives the so-called partial OT.
+"""
+
+# Author: Laetitia Chapel <laetitia.chapel@univ-ubs.fr>
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+##############################################################################
+# Generate data
+# -------------
+
+# %% parameters and data generation
+
+n = 40 # nb samples
+
+mu_s = np.array([-1, -1])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([4, 4])
+cov_t = np.array([[1, -.8], [-.8, 1]])
+
+np.random.seed(0)
+xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
+xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
+
+n_noise = 10
+
+xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) - 4))), axis=0)
+xt = np.concatenate((xt, ((np.random.rand(n_noise, 2) + 6))), axis=0)
+
+n = n + n_noise
+
+a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
+
+# loss matrix
+M = ot.dist(xs, xt)
+M /= M.max()
+
+
+##############################################################################
+# Compute entropic kl-regularized UOT, kl- and l2-regularized UOT
+# -----------
+
+reg = 0.005
+reg_m_kl = 0.05
+reg_m_l2 = 5
+mass = 0.7
+
+entropic_kl_uot = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl)
+kl_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_kl, div='kl')
+l2_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_l2, div='l2')
+partial_ot = ot.partial.partial_wasserstein(a, b, M, m=mass)
+
+##############################################################################
+# Plot the results
+# ----------------
+
+pl.figure(2)
+transp = [partial_ot, l2_uot, kl_uot, entropic_kl_uot]
+title = ["partial OT \n m=" + str(mass), "$\ell_2$-UOT \n $\mathrm{reg_m}$=" +
+ str(reg_m_l2), "kl-UOT \n $\mathrm{reg_m}$=" + str(reg_m_kl),
+ "entropic kl-UOT \n $\mathrm{reg_m}$=" + str(reg_m_kl)]
+
+for p in range(4):
+ pl.subplot(2, 4, p + 1)
+ P = transp[p]
+ if P.sum() > 0:
+ P = P / P.max()
+ for i in range(n):
+ for j in range(n):
+ if P[i, j] > 0:
+ pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2',
+ alpha=P[i, j] * 0.3)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 2)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 2)
+ pl.title(title[p])
+ pl.yticks(())
+ pl.xticks(())
+ if p < 1:
+ pl.ylabel("mappings")
+ pl.subplot(2, 4, p + 5)
+ pl.imshow(P, cmap='jet')
+ pl.yticks(())
+ pl.xticks(())
+ if p < 1:
+ pl.ylabel("transport plans")
+pl.show()