summaryrefslogtreecommitdiff
path: root/examples/unbalanced-partial
diff options
context:
space:
mode:
Diffstat (limited to 'examples/unbalanced-partial')
-rw-r--r--examples/unbalanced-partial/plot_UOT_1D.py17
-rw-r--r--examples/unbalanced-partial/plot_regpath.py88
-rw-r--r--examples/unbalanced-partial/plot_unbalanced_OT.py116
3 files changed, 218 insertions, 3 deletions
diff --git a/examples/unbalanced-partial/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py
index 183849c..06dd02d 100644
--- a/examples/unbalanced-partial/plot_UOT_1D.py
+++ b/examples/unbalanced-partial/plot_UOT_1D.py
@@ -12,6 +12,8 @@ using a Kullback-Leibler relaxation.
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 4
+
import numpy as np
import matplotlib.pylab as pl
import ot
@@ -69,7 +71,20 @@ epsilon = 0.1 # entropy parameter
alpha = 1. # Unbalanced KL relaxation parameter
Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True)
-pl.figure(4, figsize=(5, 5))
+pl.figure(3, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn')
pl.show()
+
+
+# %%
+# plot the transported mass
+# -------------------------
+
+pl.figure(4, figsize=(6.4, 3))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+pl.fill(x, Gs.sum(1), 'b', alpha=0.5, label='Transported source')
+pl.fill(x, Gs.sum(0), 'r', alpha=0.5, label='Transported target')
+pl.legend(loc='upper right')
+pl.title('Distributions and transported mass for UOT')
diff --git a/examples/unbalanced-partial/plot_regpath.py b/examples/unbalanced-partial/plot_regpath.py
index 4a51c2d..782e8c2 100644
--- a/examples/unbalanced-partial/plot_regpath.py
+++ b/examples/unbalanced-partial/plot_regpath.py
@@ -15,11 +15,12 @@ penalized linear regression.
# Author: Haoran Wu <haoran.wu@univ-ubs.fr>
# License: MIT License
+# sphinx_gallery_thumbnail_number = 2
import numpy as np
import matplotlib.pylab as pl
import ot
-
+import matplotlib.animation as animation
##############################################################################
# Generate data
# -------------
@@ -72,6 +73,9 @@ t2, t_list2, g_list2 = ot.regpath.regularization_path(a, b, M, reg=final_gamma,
##############################################################################
# Plot the regularization path
# ----------------
+#
+# The OT plan is ploted as a function of $\gamma$ that is the inverse of the
+# weight on the marginal relaxations.
#%% fully relaxed l2-penalized UOT
@@ -103,13 +107,53 @@ for p in range(4):
pl.show()
+# %%
+# Animation of the regpath for UOT l2
+# ------------------------
+
+nv = 100
+g_list_v = np.logspace(-.5, -2.5, nv)
+
+pl.figure(3)
+
+
+def _update_plot(iv):
+ pl.clf()
+ tp = ot.regpath.compute_transport_plan(g_list_v[iv], g_list,
+ t_list)
+ P = tp.reshape((n, n))
+ if P.sum() > 0:
+ P = P / P.max()
+ for i in range(n):
+ for j in range(n):
+ if P[i, j] > 0:
+ pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2',
+ alpha=P[i, j] * 0.5)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 4,
+ label='Re-weighted source', alpha=1)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 4,
+ label='Re-weighted target', alpha=1)
+ pl.plot([], [], color='C2', alpha=0.8, label='OT plan')
+ pl.title(r'$\ell_2$ UOT $\gamma$={:1.3f}'.format(g_list_v[iv]),
+ fontsize=11)
+ return 1
+
+
+i = 0
+_update_plot(i)
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000)
+
+
##############################################################################
# Plot the semi-relaxed regularization path
# -------------------
#%% semi-relaxed l2-penalized UOT
-pl.figure(3)
+pl.figure(4)
selected_gamma = [10, 1, 1e-1, 1e-2]
for p in range(4):
tp = ot.regpath.compute_transport_plan(selected_gamma[p], g_list2,
@@ -133,3 +177,43 @@ for p in range(4):
if p < 2:
pl.xticks(())
pl.show()
+
+
+# %%
+# Animation of the regpath for semi-relaxed UOT l2
+# ------------------------
+
+nv = 100
+g_list_v = np.logspace(2.5, -2, nv)
+
+pl.figure(5)
+
+
+def _update_plot(iv):
+ pl.clf()
+ tp = ot.regpath.compute_transport_plan(g_list_v[iv], g_list2,
+ t_list2)
+ P = tp.reshape((n, n))
+ if P.sum() > 0:
+ P = P / P.max()
+ for i in range(n):
+ for j in range(n):
+ if P[i, j] > 0:
+ pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2',
+ alpha=P[i, j] * 0.5)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 4,
+ label='Re-weighted source', alpha=1)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 4,
+ label='Re-weighted target', alpha=1)
+ pl.plot([], [], color='C2', alpha=0.8, label='OT plan')
+ pl.title(r'Semi-relaxed $\ell_2$ UOT $\gamma$={:1.3f}'.format(g_list_v[iv]),
+ fontsize=11)
+ return 1
+
+
+i = 0
+_update_plot(i)
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000)
diff --git a/examples/unbalanced-partial/plot_unbalanced_OT.py b/examples/unbalanced-partial/plot_unbalanced_OT.py
new file mode 100644
index 0000000..03487e7
--- /dev/null
+++ b/examples/unbalanced-partial/plot_unbalanced_OT.py
@@ -0,0 +1,116 @@
+# -*- coding: utf-8 -*-
+"""
+==============================================================
+2D examples of exact and entropic unbalanced optimal transport
+==============================================================
+This example is designed to show how to compute unbalanced and
+partial OT in POT.
+
+UOT aims at solving the following optimization problem:
+
+ .. math::
+ W = \min_{\gamma} <\gamma, \mathbf{M}>_F +
+ \mathrm{reg}\cdot\Omega(\gamma) +
+ \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) +
+ \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b})
+
+ s.t.
+ \gamma \geq 0
+
+where :math:`\mathrm{div}` is a divergence.
+When using the entropic UOT, :math:`\mathrm{reg}>0` and :math:`\mathrm{div}`
+should be the Kullback-Leibler divergence.
+When solving exact UOT, :math:`\mathrm{reg}=0` and :math:`\mathrm{div}`
+can be either the Kullback-Leibler or the quadratic divergence.
+Using :math:`\ell_1` norm gives the so-called partial OT.
+"""
+
+# Author: Laetitia Chapel <laetitia.chapel@univ-ubs.fr>
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+##############################################################################
+# Generate data
+# -------------
+
+# %% parameters and data generation
+
+n = 40 # nb samples
+
+mu_s = np.array([-1, -1])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([4, 4])
+cov_t = np.array([[1, -.8], [-.8, 1]])
+
+np.random.seed(0)
+xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
+xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
+
+n_noise = 10
+
+xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) - 4))), axis=0)
+xt = np.concatenate((xt, ((np.random.rand(n_noise, 2) + 6))), axis=0)
+
+n = n + n_noise
+
+a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
+
+# loss matrix
+M = ot.dist(xs, xt)
+M /= M.max()
+
+
+##############################################################################
+# Compute entropic kl-regularized UOT, kl- and l2-regularized UOT
+# -----------
+
+reg = 0.005
+reg_m_kl = 0.05
+reg_m_l2 = 5
+mass = 0.7
+
+entropic_kl_uot = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl)
+kl_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_kl, div='kl')
+l2_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_l2, div='l2')
+partial_ot = ot.partial.partial_wasserstein(a, b, M, m=mass)
+
+##############################################################################
+# Plot the results
+# ----------------
+
+pl.figure(2)
+transp = [partial_ot, l2_uot, kl_uot, entropic_kl_uot]
+title = ["partial OT \n m=" + str(mass), "$\ell_2$-UOT \n $\mathrm{reg_m}$=" +
+ str(reg_m_l2), "kl-UOT \n $\mathrm{reg_m}$=" + str(reg_m_kl),
+ "entropic kl-UOT \n $\mathrm{reg_m}$=" + str(reg_m_kl)]
+
+for p in range(4):
+ pl.subplot(2, 4, p + 1)
+ P = transp[p]
+ if P.sum() > 0:
+ P = P / P.max()
+ for i in range(n):
+ for j in range(n):
+ if P[i, j] > 0:
+ pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2',
+ alpha=P[i, j] * 0.3)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 2)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 2)
+ pl.title(title[p])
+ pl.yticks(())
+ pl.xticks(())
+ if p < 1:
+ pl.ylabel("mappings")
+ pl.subplot(2, 4, p + 5)
+ pl.imshow(P, cmap='jet')
+ pl.yticks(())
+ pl.xticks(())
+ if p < 1:
+ pl.ylabel("transport plans")
+pl.show()