summaryrefslogtreecommitdiff
path: root/examples/plot_stochastic.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/plot_stochastic.py')
-rw-r--r--examples/plot_stochastic.py101
1 files changed, 41 insertions, 60 deletions
diff --git a/examples/plot_stochastic.py b/examples/plot_stochastic.py
index 742f8d9..3a1ef31 100644
--- a/examples/plot_stochastic.py
+++ b/examples/plot_stochastic.py
@@ -1,10 +1,18 @@
"""
-==========================
+===================
Stochastic examples
-==========================
+===================
This example is designed to show how to use the stochatic optimization
-algorithms for descrete and semicontinous measures from the POT library.
+algorithms for discrete and semi-continuous measures from the POT library.
+
+[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F.
+Stochastic Optimization for Large-scale Optimal Transport.
+Advances in Neural Information Processing Systems (2016).
+
+[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. &
+Blondel, M. Large-scale Optimal Transport and Mapping Estimation.
+International Conference on Learning Representation (2018)
"""
@@ -19,16 +27,14 @@ import ot.plot
#############################################################################
-# COMPUTE TRANSPORTATION MATRIX FOR SEMI-DUAL PROBLEM
-#############################################################################
-#############################################################################
-# DISCRETE CASE:
+# Compute the Transportation Matrix for the Semi-Dual Problem
+# -----------------------------------------------------------
#
-# Sample two discrete measures for the discrete case
-# ---------------------------------------------
+# Discrete case
+# `````````````
#
-# Define 2 discrete measures a and b, the points where are defined the source
-# and the target measures and finally the cost matrix c.
+# Sample two discrete measures for the discrete case and compute their cost
+# matrix c.
n_source = 7
n_target = 4
@@ -44,12 +50,7 @@ Y_target = rng.randn(n_target, 2)
M = ot.dist(X_source, Y_target)
#############################################################################
-#
# Call the "SAG" method to find the transportation matrix in the discrete case
-# ---------------------------------------------
-#
-# Define the method "SAG", call ot.solve_semi_dual_entropic and plot the
-# results.
method = "SAG"
sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
@@ -57,14 +58,12 @@ sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
print(sag_pi)
#############################################################################
-# SEMICONTINOUS CASE:
+# Semi-Continuous Case
+# ````````````````````
#
# Sample one general measure a, one discrete measures b for the semicontinous
-# case
-# ---------------------------------------------
-#
-# Define one general measure a, one discrete measures b, the points where
-# are defined the source and the target measures and finally the cost matrix c.
+# case, the points where source and target measures are defined and compute the
+# cost matrix.
n_source = 7
n_target = 4
@@ -81,13 +80,8 @@ Y_target = rng.randn(n_target, 2)
M = ot.dist(X_source, Y_target)
#############################################################################
-#
# Call the "ASGD" method to find the transportation matrix in the semicontinous
-# case
-# ---------------------------------------------
-#
-# Define the method "ASGD", call ot.solve_semi_dual_entropic and plot the
-# results.
+# case.
method = "ASGD"
asgd_pi, log_asgd = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
@@ -96,23 +90,17 @@ print(log_asgd['alpha'], log_asgd['beta'])
print(asgd_pi)
#############################################################################
-#
# Compare the results with the Sinkhorn algorithm
-# ---------------------------------------------
-#
-# Call the Sinkhorn algorithm from POT
sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
print(sinkhorn_pi)
##############################################################################
-# PLOT TRANSPORTATION MATRIX
-##############################################################################
-
-##############################################################################
-# Plot SAG results
-# ----------------
+# Plot Transportation Matrices
+# ````````````````````````````
+#
+# For SAG
pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sag_pi, 'semi-dual : OT matrix SAG')
@@ -120,8 +108,7 @@ pl.show()
##############################################################################
-# Plot ASGD results
-# -----------------
+# For ASGD
pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, asgd_pi, 'semi-dual : OT matrix ASGD')
@@ -129,8 +116,7 @@ pl.show()
##############################################################################
-# Plot Sinkhorn results
-# ---------------------
+# For Sinkhorn
pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
@@ -138,17 +124,14 @@ pl.show()
#############################################################################
-# COMPUTE TRANSPORTATION MATRIX FOR DUAL PROBLEM
-#############################################################################
-#############################################################################
-# SEMICONTINOUS CASE:
+# Compute the Transportation Matrix for the Dual Problem
+# ------------------------------------------------------
#
-# Sample one general measure a, one discrete measures b for the semicontinous
-# case
-# ---------------------------------------------
+# Semi-continuous case
+# ````````````````````
#
-# Define one general measure a, one discrete measures b, the points where
-# are defined the source and the target measures and finally the cost matrix c.
+# Sample one general measure a, one discrete measures b for the semi-continuous
+# case and compute the cost matrix c.
n_source = 7
n_target = 4
@@ -169,10 +152,7 @@ M = ot.dist(X_source, Y_target)
#############################################################################
#
# Call the "SGD" dual method to find the transportation matrix in the
-# semicontinous case
-# ---------------------------------------------
-#
-# Call ot.solve_dual_entropic and plot the results.
+# semi-continuous case
sgd_dual_pi, log_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg,
batch_size, numItermax,
@@ -183,7 +163,7 @@ print(sgd_dual_pi)
#############################################################################
#
# Compare the results with the Sinkhorn algorithm
-# ---------------------------------------------
+# ```````````````````````````````````````````````
#
# Call the Sinkhorn algorithm from POT
@@ -191,8 +171,10 @@ sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
print(sinkhorn_pi)
##############################################################################
-# Plot SGD results
-# -----------------
+# Plot Transportation Matrices
+# ````````````````````````````
+#
+# For SGD
pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sgd_dual_pi, 'dual : OT matrix SGD')
@@ -200,8 +182,7 @@ pl.show()
##############################################################################
-# Plot Sinkhorn results
-# ---------------------
+# For Sinkhorn
pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')