summaryrefslogtreecommitdiff
path: root/examples/unbalanced-partial/plot_regpath.py
diff options
context:
space:
mode:
authorGard Spreemann <gspr@nonempty.org>2022-04-27 11:49:23 +0200
committerGard Spreemann <gspr@nonempty.org>2022-04-27 11:49:23 +0200
commit35bd2c98b642df78638d7d733bc1a89d873db1de (patch)
tree6bc637624004713808d3097b95acdccbb9608e52 /examples/unbalanced-partial/plot_regpath.py
parentc4753bd3f74139af8380127b66b484bc09b50661 (diff)
parenteccb1386eea52b94b82456d126bd20cbe3198e05 (diff)
Merge tag '0.8.2' into dfsg/latest
Diffstat (limited to 'examples/unbalanced-partial/plot_regpath.py')
-rw-r--r--examples/unbalanced-partial/plot_regpath.py88
1 files changed, 86 insertions, 2 deletions
diff --git a/examples/unbalanced-partial/plot_regpath.py b/examples/unbalanced-partial/plot_regpath.py
index 4a51c2d..782e8c2 100644
--- a/examples/unbalanced-partial/plot_regpath.py
+++ b/examples/unbalanced-partial/plot_regpath.py
@@ -15,11 +15,12 @@ penalized linear regression.
# Author: Haoran Wu <haoran.wu@univ-ubs.fr>
# License: MIT License
+# sphinx_gallery_thumbnail_number = 2
import numpy as np
import matplotlib.pylab as pl
import ot
-
+import matplotlib.animation as animation
##############################################################################
# Generate data
# -------------
@@ -72,6 +73,9 @@ t2, t_list2, g_list2 = ot.regpath.regularization_path(a, b, M, reg=final_gamma,
##############################################################################
# Plot the regularization path
# ----------------
+#
+# The OT plan is ploted as a function of $\gamma$ that is the inverse of the
+# weight on the marginal relaxations.
#%% fully relaxed l2-penalized UOT
@@ -103,13 +107,53 @@ for p in range(4):
pl.show()
+# %%
+# Animation of the regpath for UOT l2
+# ------------------------
+
+nv = 100
+g_list_v = np.logspace(-.5, -2.5, nv)
+
+pl.figure(3)
+
+
+def _update_plot(iv):
+ pl.clf()
+ tp = ot.regpath.compute_transport_plan(g_list_v[iv], g_list,
+ t_list)
+ P = tp.reshape((n, n))
+ if P.sum() > 0:
+ P = P / P.max()
+ for i in range(n):
+ for j in range(n):
+ if P[i, j] > 0:
+ pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2',
+ alpha=P[i, j] * 0.5)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 4,
+ label='Re-weighted source', alpha=1)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 4,
+ label='Re-weighted target', alpha=1)
+ pl.plot([], [], color='C2', alpha=0.8, label='OT plan')
+ pl.title(r'$\ell_2$ UOT $\gamma$={:1.3f}'.format(g_list_v[iv]),
+ fontsize=11)
+ return 1
+
+
+i = 0
+_update_plot(i)
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000)
+
+
##############################################################################
# Plot the semi-relaxed regularization path
# -------------------
#%% semi-relaxed l2-penalized UOT
-pl.figure(3)
+pl.figure(4)
selected_gamma = [10, 1, 1e-1, 1e-2]
for p in range(4):
tp = ot.regpath.compute_transport_plan(selected_gamma[p], g_list2,
@@ -133,3 +177,43 @@ for p in range(4):
if p < 2:
pl.xticks(())
pl.show()
+
+
+# %%
+# Animation of the regpath for semi-relaxed UOT l2
+# ------------------------
+
+nv = 100
+g_list_v = np.logspace(2.5, -2, nv)
+
+pl.figure(5)
+
+
+def _update_plot(iv):
+ pl.clf()
+ tp = ot.regpath.compute_transport_plan(g_list_v[iv], g_list2,
+ t_list2)
+ P = tp.reshape((n, n))
+ if P.sum() > 0:
+ P = P / P.max()
+ for i in range(n):
+ for j in range(n):
+ if P[i, j] > 0:
+ pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2',
+ alpha=P[i, j] * 0.5)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 4,
+ label='Re-weighted source', alpha=1)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 4,
+ label='Re-weighted target', alpha=1)
+ pl.plot([], [], color='C2', alpha=0.8, label='OT plan')
+ pl.title(r'Semi-relaxed $\ell_2$ UOT $\gamma$={:1.3f}'.format(g_list_v[iv]),
+ fontsize=11)
+ return 1
+
+
+i = 0
+_update_plot(i)
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000)