summaryrefslogtreecommitdiff
path: root/examples/plot_OT_1D_smooth.py
diff options
context:
space:
mode:
authorTianlin Liu <tliu@jacobs-alumni.de>2023-04-25 12:14:29 +0200
committerGitHub <noreply@github.com>2023-04-25 12:14:29 +0200
commit42a62c123776e04ee805aefb9afd6d98abdcf192 (patch)
treed439a1478c2f148c89678adc07736834b41255d4 /examples/plot_OT_1D_smooth.py
parent03ca4ef659a037e400975e3b2116b637a2d94265 (diff)
[FEAT] add the sparsity-constrained optimal transport funtionality and example (#459)
* add sparsity-constrained ot funtionality and example * correct typos; add projection_sparse_simplex * add gradcheck; merge ot.sparse into ot.smooth. * reuse existing ot.smooth functions with a new 'sparsity_constrained' reg_type * address pep8 error * add backends for * update releases --------- Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'examples/plot_OT_1D_smooth.py')
-rw-r--r--examples/plot_OT_1D_smooth.py51
1 files changed, 19 insertions, 32 deletions
diff --git a/examples/plot_OT_1D_smooth.py b/examples/plot_OT_1D_smooth.py
index 5415e4f..ff51b8a 100644
--- a/examples/plot_OT_1D_smooth.py
+++ b/examples/plot_OT_1D_smooth.py
@@ -1,11 +1,12 @@
# -*- coding: utf-8 -*-
"""
================================
-Smooth optimal transport example
+Smooth and sparse OT example
================================
-This example illustrates the computation of EMD, Sinkhorn and smooth OT plans
-and their visualization.
+This example illustrates the computation of
+Smooth and Sparse (KL an L2 reg.) OT and
+sparsity-constrained OT, together with their visualizations.
"""
@@ -58,32 +59,6 @@ pl.legend()
pl.figure(2, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
-##############################################################################
-# Solve EMD
-# ---------
-
-
-#%% EMD
-
-G0 = ot.emd(a, b, M)
-
-pl.figure(3, figsize=(5, 5))
-ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
-
-##############################################################################
-# Solve Sinkhorn
-# --------------
-
-
-#%% Sinkhorn
-
-lambd = 2e-3
-Gs = ot.sinkhorn(a, b, M, lambd, verbose=True)
-
-pl.figure(4, figsize=(5, 5))
-ot.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn')
-
-pl.show()
##############################################################################
# Solve Smooth OT
@@ -95,18 +70,30 @@ pl.show()
lambd = 2e-3
Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='kl')
-pl.figure(5, figsize=(5, 5))
+pl.figure(3, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT KL reg.')
pl.show()
-#%% Smooth OT with KL regularization
+#%% Smooth OT with squared l2 regularization
lambd = 1e-1
Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='l2')
-pl.figure(6, figsize=(5, 5))
+pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT l2 reg.')
pl.show()
+
+#%% Sparsity-constrained OT
+
+lambd = 1e-1
+
+max_nz = 2 # two non-zero entries are permitted per column of the OT plan
+Gsc = ot.smooth.smooth_ot_dual(
+ a, b, M, lambd, reg_type='sparsity_constrained', max_nz=max_nz)
+pl.figure(5, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Gsc, 'Sparsity contrained OT matrix; k=2.')
+
+pl.show()