From ad02112d4288f3efdd5bc6fc6e45444313bba871 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Tue, 5 Apr 2022 11:57:10 +0200 Subject: [MRG] Update examples in the doc (#359) * add transparent color logo * add transparent color logo * move screenkhorn * move stochastic and install ffmpeg on circleci * try something * add sudo * install ffmpeg before python * cleanup examples * test svg scrapper * add animation for reg path * better example OT sivergence * update ttles and add plots * update free support * proper figure indexes * have less frame sin animation * update readme and release file * add tests for python 3.10 --- examples/unbalanced-partial/plot_UOT_1D.py | 17 +++++- examples/unbalanced-partial/plot_regpath.py | 88 ++++++++++++++++++++++++++++- 2 files changed, 102 insertions(+), 3 deletions(-) (limited to 'examples/unbalanced-partial') diff --git a/examples/unbalanced-partial/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py index 183849c..06dd02d 100644 --- a/examples/unbalanced-partial/plot_UOT_1D.py +++ b/examples/unbalanced-partial/plot_UOT_1D.py @@ -12,6 +12,8 @@ using a Kullback-Leibler relaxation. # # License: MIT License +# sphinx_gallery_thumbnail_number = 4 + import numpy as np import matplotlib.pylab as pl import ot @@ -69,7 +71,20 @@ epsilon = 0.1 # entropy parameter alpha = 1. # Unbalanced KL relaxation parameter Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True) -pl.figure(4, figsize=(5, 5)) +pl.figure(3, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn') pl.show() + + +# %% +# plot the transported mass +# ------------------------- + +pl.figure(4, figsize=(6.4, 3)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') +pl.fill(x, Gs.sum(1), 'b', alpha=0.5, label='Transported source') +pl.fill(x, Gs.sum(0), 'r', alpha=0.5, label='Transported target') +pl.legend(loc='upper right') +pl.title('Distributions and transported mass for UOT') diff --git a/examples/unbalanced-partial/plot_regpath.py b/examples/unbalanced-partial/plot_regpath.py index 4a51c2d..782e8c2 100644 --- a/examples/unbalanced-partial/plot_regpath.py +++ b/examples/unbalanced-partial/plot_regpath.py @@ -15,11 +15,12 @@ penalized linear regression. # Author: Haoran Wu # License: MIT License +# sphinx_gallery_thumbnail_number = 2 import numpy as np import matplotlib.pylab as pl import ot - +import matplotlib.animation as animation ############################################################################## # Generate data # ------------- @@ -72,6 +73,9 @@ t2, t_list2, g_list2 = ot.regpath.regularization_path(a, b, M, reg=final_gamma, ############################################################################## # Plot the regularization path # ---------------- +# +# The OT plan is ploted as a function of $\gamma$ that is the inverse of the +# weight on the marginal relaxations. #%% fully relaxed l2-penalized UOT @@ -103,13 +107,53 @@ for p in range(4): pl.show() +# %% +# Animation of the regpath for UOT l2 +# ------------------------ + +nv = 100 +g_list_v = np.logspace(-.5, -2.5, nv) + +pl.figure(3) + + +def _update_plot(iv): + pl.clf() + tp = ot.regpath.compute_transport_plan(g_list_v[iv], g_list, + t_list) + P = tp.reshape((n, n)) + if P.sum() > 0: + P = P / P.max() + for i in range(n): + for j in range(n): + if P[i, j] > 0: + pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2', + alpha=P[i, j] * 0.5) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 4, + label='Re-weighted source', alpha=1) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 4, + label='Re-weighted target', alpha=1) + pl.plot([], [], color='C2', alpha=0.8, label='OT plan') + pl.title(r'$\ell_2$ UOT $\gamma$={:1.3f}'.format(g_list_v[iv]), + fontsize=11) + return 1 + + +i = 0 +_update_plot(i) + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000) + + ############################################################################## # Plot the semi-relaxed regularization path # ------------------- #%% semi-relaxed l2-penalized UOT -pl.figure(3) +pl.figure(4) selected_gamma = [10, 1, 1e-1, 1e-2] for p in range(4): tp = ot.regpath.compute_transport_plan(selected_gamma[p], g_list2, @@ -133,3 +177,43 @@ for p in range(4): if p < 2: pl.xticks(()) pl.show() + + +# %% +# Animation of the regpath for semi-relaxed UOT l2 +# ------------------------ + +nv = 100 +g_list_v = np.logspace(2.5, -2, nv) + +pl.figure(5) + + +def _update_plot(iv): + pl.clf() + tp = ot.regpath.compute_transport_plan(g_list_v[iv], g_list2, + t_list2) + P = tp.reshape((n, n)) + if P.sum() > 0: + P = P / P.max() + for i in range(n): + for j in range(n): + if P[i, j] > 0: + pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2', + alpha=P[i, j] * 0.5) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 4, + label='Re-weighted source', alpha=1) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 4, + label='Re-weighted target', alpha=1) + pl.plot([], [], color='C2', alpha=0.8, label='OT plan') + pl.title(r'Semi-relaxed $\ell_2$ UOT $\gamma$={:1.3f}'.format(g_list_v[iv]), + fontsize=11) + return 1 + + +i = 0 +_update_plot(i) + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000) -- cgit v1.2.3