From ac4cf442735ed4c0d5405ad861eddaa02afd4edd Mon Sep 17 00:00:00 2001 From: Laetitia Chapel Date: Mon, 11 Apr 2022 15:38:18 +0200 Subject: [MRG] MM algorithms for UOT (#362) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * bugfix * update refs partial OT * fixes small typos in plot_partial_wass_and_gromov * fix small bugs in partial.py * update README * pep8 bugfix * modif doctest * fix bugtests * update on test_partial and test on the numerical precision on ot/partial * resolve merge pb * Delete partial.py * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * update releases.md with new MM UOT algorithms Co-authored-by: RĂ©mi Flamary --- examples/unbalanced-partial/plot_unbalanced_OT.py | 116 ++++++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 examples/unbalanced-partial/plot_unbalanced_OT.py (limited to 'examples') diff --git a/examples/unbalanced-partial/plot_unbalanced_OT.py b/examples/unbalanced-partial/plot_unbalanced_OT.py new file mode 100644 index 0000000..03487e7 --- /dev/null +++ b/examples/unbalanced-partial/plot_unbalanced_OT.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +""" +============================================================== +2D examples of exact and entropic unbalanced optimal transport +============================================================== +This example is designed to show how to compute unbalanced and +partial OT in POT. + +UOT aims at solving the following optimization problem: + + .. math:: + W = \min_{\gamma} <\gamma, \mathbf{M}>_F + + \mathrm{reg}\cdot\Omega(\gamma) + + \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + + s.t. + \gamma \geq 0 + +where :math:`\mathrm{div}` is a divergence. +When using the entropic UOT, :math:`\mathrm{reg}>0` and :math:`\mathrm{div}` +should be the Kullback-Leibler divergence. +When solving exact UOT, :math:`\mathrm{reg}=0` and :math:`\mathrm{div}` +can be either the Kullback-Leibler or the quadratic divergence. +Using :math:`\ell_1` norm gives the so-called partial OT. +""" + +# Author: Laetitia Chapel +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot + +############################################################################## +# Generate data +# ------------- + +# %% parameters and data generation + +n = 40 # nb samples + +mu_s = np.array([-1, -1]) +cov_s = np.array([[1, 0], [0, 1]]) + +mu_t = np.array([4, 4]) +cov_t = np.array([[1, -.8], [-.8, 1]]) + +np.random.seed(0) +xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) +xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) + +n_noise = 10 + +xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) - 4))), axis=0) +xt = np.concatenate((xt, ((np.random.rand(n_noise, 2) + 6))), axis=0) + +n = n + n_noise + +a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples + +# loss matrix +M = ot.dist(xs, xt) +M /= M.max() + + +############################################################################## +# Compute entropic kl-regularized UOT, kl- and l2-regularized UOT +# ----------- + +reg = 0.005 +reg_m_kl = 0.05 +reg_m_l2 = 5 +mass = 0.7 + +entropic_kl_uot = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl) +kl_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_kl, div='kl') +l2_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_l2, div='l2') +partial_ot = ot.partial.partial_wasserstein(a, b, M, m=mass) + +############################################################################## +# Plot the results +# ---------------- + +pl.figure(2) +transp = [partial_ot, l2_uot, kl_uot, entropic_kl_uot] +title = ["partial OT \n m=" + str(mass), "$\ell_2$-UOT \n $\mathrm{reg_m}$=" + + str(reg_m_l2), "kl-UOT \n $\mathrm{reg_m}$=" + str(reg_m_kl), + "entropic kl-UOT \n $\mathrm{reg_m}$=" + str(reg_m_kl)] + +for p in range(4): + pl.subplot(2, 4, p + 1) + P = transp[p] + if P.sum() > 0: + P = P / P.max() + for i in range(n): + for j in range(n): + if P[i, j] > 0: + pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2', + alpha=P[i, j] * 0.3) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 2) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 2) + pl.title(title[p]) + pl.yticks(()) + pl.xticks(()) + if p < 1: + pl.ylabel("mappings") + pl.subplot(2, 4, p + 5) + pl.imshow(P, cmap='jet') + pl.yticks(()) + pl.xticks(()) + if p < 1: + pl.ylabel("transport plans") +pl.show() -- cgit v1.2.3