diff options
Diffstat (limited to 'examples')
-rw-r--r-- | examples/plot_OT_1D_smooth.py | 51 |
1 files changed, 19 insertions, 32 deletions
diff --git a/examples/plot_OT_1D_smooth.py b/examples/plot_OT_1D_smooth.py index 5415e4f..ff51b8a 100644 --- a/examples/plot_OT_1D_smooth.py +++ b/examples/plot_OT_1D_smooth.py @@ -1,11 +1,12 @@ # -*- coding: utf-8 -*- """ ================================ -Smooth optimal transport example +Smooth and sparse OT example ================================ -This example illustrates the computation of EMD, Sinkhorn and smooth OT plans -and their visualization. +This example illustrates the computation of +Smooth and Sparse (KL an L2 reg.) OT and +sparsity-constrained OT, together with their visualizations. """ @@ -58,32 +59,6 @@ pl.legend() pl.figure(2, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') -############################################################################## -# Solve EMD -# --------- - - -#%% EMD - -G0 = ot.emd(a, b, M) - -pl.figure(3, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0') - -############################################################################## -# Solve Sinkhorn -# -------------- - - -#%% Sinkhorn - -lambd = 2e-3 -Gs = ot.sinkhorn(a, b, M, lambd, verbose=True) - -pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn') - -pl.show() ############################################################################## # Solve Smooth OT @@ -95,18 +70,30 @@ pl.show() lambd = 2e-3 Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='kl') -pl.figure(5, figsize=(5, 5)) +pl.figure(3, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT KL reg.') pl.show() -#%% Smooth OT with KL regularization +#%% Smooth OT with squared l2 regularization lambd = 1e-1 Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='l2') -pl.figure(6, figsize=(5, 5)) +pl.figure(4, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT l2 reg.') pl.show() + +#%% Sparsity-constrained OT + +lambd = 1e-1 + +max_nz = 2 # two non-zero entries are permitted per column of the OT plan +Gsc = ot.smooth.smooth_ot_dual( + a, b, M, lambd, reg_type='sparsity_constrained', max_nz=max_nz) +pl.figure(5, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, Gsc, 'Sparsity contrained OT matrix; k=2.') + +pl.show() |