diff options
Diffstat (limited to 'examples/unbalanced-partial')
-rw-r--r-- | examples/unbalanced-partial/plot_UOT_1D.py | 17 | ||||
-rw-r--r-- | examples/unbalanced-partial/plot_regpath.py | 88 | ||||
-rw-r--r-- | examples/unbalanced-partial/plot_unbalanced_OT.py | 116 |
3 files changed, 218 insertions, 3 deletions
diff --git a/examples/unbalanced-partial/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py index 183849c..06dd02d 100644 --- a/examples/unbalanced-partial/plot_UOT_1D.py +++ b/examples/unbalanced-partial/plot_UOT_1D.py @@ -12,6 +12,8 @@ using a Kullback-Leibler relaxation. # # License: MIT License +# sphinx_gallery_thumbnail_number = 4 + import numpy as np import matplotlib.pylab as pl import ot @@ -69,7 +71,20 @@ epsilon = 0.1 # entropy parameter alpha = 1. # Unbalanced KL relaxation parameter Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True) -pl.figure(4, figsize=(5, 5)) +pl.figure(3, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn') pl.show() + + +# %% +# plot the transported mass +# ------------------------- + +pl.figure(4, figsize=(6.4, 3)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') +pl.fill(x, Gs.sum(1), 'b', alpha=0.5, label='Transported source') +pl.fill(x, Gs.sum(0), 'r', alpha=0.5, label='Transported target') +pl.legend(loc='upper right') +pl.title('Distributions and transported mass for UOT') diff --git a/examples/unbalanced-partial/plot_regpath.py b/examples/unbalanced-partial/plot_regpath.py index 4a51c2d..782e8c2 100644 --- a/examples/unbalanced-partial/plot_regpath.py +++ b/examples/unbalanced-partial/plot_regpath.py @@ -15,11 +15,12 @@ penalized linear regression. # Author: Haoran Wu <haoran.wu@univ-ubs.fr> # License: MIT License +# sphinx_gallery_thumbnail_number = 2 import numpy as np import matplotlib.pylab as pl import ot - +import matplotlib.animation as animation ############################################################################## # Generate data # ------------- @@ -72,6 +73,9 @@ t2, t_list2, g_list2 = ot.regpath.regularization_path(a, b, M, reg=final_gamma, ############################################################################## # Plot the regularization path # ---------------- +# +# The OT plan is ploted as a function of $\gamma$ that is the inverse of the +# weight on the marginal relaxations. #%% fully relaxed l2-penalized UOT @@ -103,13 +107,53 @@ for p in range(4): pl.show() +# %% +# Animation of the regpath for UOT l2 +# ------------------------ + +nv = 100 +g_list_v = np.logspace(-.5, -2.5, nv) + +pl.figure(3) + + +def _update_plot(iv): + pl.clf() + tp = ot.regpath.compute_transport_plan(g_list_v[iv], g_list, + t_list) + P = tp.reshape((n, n)) + if P.sum() > 0: + P = P / P.max() + for i in range(n): + for j in range(n): + if P[i, j] > 0: + pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2', + alpha=P[i, j] * 0.5) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 4, + label='Re-weighted source', alpha=1) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 4, + label='Re-weighted target', alpha=1) + pl.plot([], [], color='C2', alpha=0.8, label='OT plan') + pl.title(r'$\ell_2$ UOT $\gamma$={:1.3f}'.format(g_list_v[iv]), + fontsize=11) + return 1 + + +i = 0 +_update_plot(i) + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000) + + ############################################################################## # Plot the semi-relaxed regularization path # ------------------- #%% semi-relaxed l2-penalized UOT -pl.figure(3) +pl.figure(4) selected_gamma = [10, 1, 1e-1, 1e-2] for p in range(4): tp = ot.regpath.compute_transport_plan(selected_gamma[p], g_list2, @@ -133,3 +177,43 @@ for p in range(4): if p < 2: pl.xticks(()) pl.show() + + +# %% +# Animation of the regpath for semi-relaxed UOT l2 +# ------------------------ + +nv = 100 +g_list_v = np.logspace(2.5, -2, nv) + +pl.figure(5) + + +def _update_plot(iv): + pl.clf() + tp = ot.regpath.compute_transport_plan(g_list_v[iv], g_list2, + t_list2) + P = tp.reshape((n, n)) + if P.sum() > 0: + P = P / P.max() + for i in range(n): + for j in range(n): + if P[i, j] > 0: + pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2', + alpha=P[i, j] * 0.5) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 4, + label='Re-weighted source', alpha=1) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 4, + label='Re-weighted target', alpha=1) + pl.plot([], [], color='C2', alpha=0.8, label='OT plan') + pl.title(r'Semi-relaxed $\ell_2$ UOT $\gamma$={:1.3f}'.format(g_list_v[iv]), + fontsize=11) + return 1 + + +i = 0 +_update_plot(i) + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000) diff --git a/examples/unbalanced-partial/plot_unbalanced_OT.py b/examples/unbalanced-partial/plot_unbalanced_OT.py new file mode 100644 index 0000000..03487e7 --- /dev/null +++ b/examples/unbalanced-partial/plot_unbalanced_OT.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +""" +============================================================== +2D examples of exact and entropic unbalanced optimal transport +============================================================== +This example is designed to show how to compute unbalanced and +partial OT in POT. + +UOT aims at solving the following optimization problem: + + .. math:: + W = \min_{\gamma} <\gamma, \mathbf{M}>_F + + \mathrm{reg}\cdot\Omega(\gamma) + + \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + + s.t. + \gamma \geq 0 + +where :math:`\mathrm{div}` is a divergence. +When using the entropic UOT, :math:`\mathrm{reg}>0` and :math:`\mathrm{div}` +should be the Kullback-Leibler divergence. +When solving exact UOT, :math:`\mathrm{reg}=0` and :math:`\mathrm{div}` +can be either the Kullback-Leibler or the quadratic divergence. +Using :math:`\ell_1` norm gives the so-called partial OT. +""" + +# Author: Laetitia Chapel <laetitia.chapel@univ-ubs.fr> +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot + +############################################################################## +# Generate data +# ------------- + +# %% parameters and data generation + +n = 40 # nb samples + +mu_s = np.array([-1, -1]) +cov_s = np.array([[1, 0], [0, 1]]) + +mu_t = np.array([4, 4]) +cov_t = np.array([[1, -.8], [-.8, 1]]) + +np.random.seed(0) +xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) +xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) + +n_noise = 10 + +xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) - 4))), axis=0) +xt = np.concatenate((xt, ((np.random.rand(n_noise, 2) + 6))), axis=0) + +n = n + n_noise + +a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples + +# loss matrix +M = ot.dist(xs, xt) +M /= M.max() + + +############################################################################## +# Compute entropic kl-regularized UOT, kl- and l2-regularized UOT +# ----------- + +reg = 0.005 +reg_m_kl = 0.05 +reg_m_l2 = 5 +mass = 0.7 + +entropic_kl_uot = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl) +kl_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_kl, div='kl') +l2_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_l2, div='l2') +partial_ot = ot.partial.partial_wasserstein(a, b, M, m=mass) + +############################################################################## +# Plot the results +# ---------------- + +pl.figure(2) +transp = [partial_ot, l2_uot, kl_uot, entropic_kl_uot] +title = ["partial OT \n m=" + str(mass), "$\ell_2$-UOT \n $\mathrm{reg_m}$=" + + str(reg_m_l2), "kl-UOT \n $\mathrm{reg_m}$=" + str(reg_m_kl), + "entropic kl-UOT \n $\mathrm{reg_m}$=" + str(reg_m_kl)] + +for p in range(4): + pl.subplot(2, 4, p + 1) + P = transp[p] + if P.sum() > 0: + P = P / P.max() + for i in range(n): + for j in range(n): + if P[i, j] > 0: + pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2', + alpha=P[i, j] * 0.3) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 2) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 2) + pl.title(title[p]) + pl.yticks(()) + pl.xticks(()) + if p < 1: + pl.ylabel("mappings") + pl.subplot(2, 4, p + 5) + pl.imshow(P, cmap='jet') + pl.yticks(()) + pl.xticks(()) + if p < 1: + pl.ylabel("transport plans") +pl.show() |