diff options
author | Gard Spreemann <gspr@nonempty.org> | 2020-07-09 08:50:11 +0200 |
---|---|---|
committer | Gard Spreemann <gspr@nonempty.org> | 2020-07-09 08:50:11 +0200 |
commit | 616c6039614e96b6cfd88eb7db25ce11c7302c30 (patch) | |
tree | a745ec24604f660b256dbaee214affa497c9c21e /examples/plot_stochastic.py | |
parent | d62f05daf348b4e554056f298c66cbd64f5e3c6e (diff) | |
parent | a16b9471d7114ec08977479b7249efe747702b97 (diff) |
Merge branch 'dfsg/latest' into debian/sid
Diffstat (limited to 'examples/plot_stochastic.py')
-rw-r--r-- | examples/plot_stochastic.py | 101 |
1 files changed, 41 insertions, 60 deletions
diff --git a/examples/plot_stochastic.py b/examples/plot_stochastic.py index 742f8d9..3a1ef31 100644 --- a/examples/plot_stochastic.py +++ b/examples/plot_stochastic.py @@ -1,10 +1,18 @@ """ -========================== +=================== Stochastic examples -========================== +=================== This example is designed to show how to use the stochatic optimization -algorithms for descrete and semicontinous measures from the POT library. +algorithms for discrete and semi-continuous measures from the POT library. + +[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. +Stochastic Optimization for Large-scale Optimal Transport. +Advances in Neural Information Processing Systems (2016). + +[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & +Blondel, M. Large-scale Optimal Transport and Mapping Estimation. +International Conference on Learning Representation (2018) """ @@ -19,16 +27,14 @@ import ot.plot ############################################################################# -# COMPUTE TRANSPORTATION MATRIX FOR SEMI-DUAL PROBLEM -############################################################################# -############################################################################# -# DISCRETE CASE: +# Compute the Transportation Matrix for the Semi-Dual Problem +# ----------------------------------------------------------- # -# Sample two discrete measures for the discrete case -# --------------------------------------------- +# Discrete case +# ````````````` # -# Define 2 discrete measures a and b, the points where are defined the source -# and the target measures and finally the cost matrix c. +# Sample two discrete measures for the discrete case and compute their cost +# matrix c. n_source = 7 n_target = 4 @@ -44,12 +50,7 @@ Y_target = rng.randn(n_target, 2) M = ot.dist(X_source, Y_target) ############################################################################# -# # Call the "SAG" method to find the transportation matrix in the discrete case -# --------------------------------------------- -# -# Define the method "SAG", call ot.solve_semi_dual_entropic and plot the -# results. method = "SAG" sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, @@ -57,14 +58,12 @@ sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, print(sag_pi) ############################################################################# -# SEMICONTINOUS CASE: +# Semi-Continuous Case +# ```````````````````` # # Sample one general measure a, one discrete measures b for the semicontinous -# case -# --------------------------------------------- -# -# Define one general measure a, one discrete measures b, the points where -# are defined the source and the target measures and finally the cost matrix c. +# case, the points where source and target measures are defined and compute the +# cost matrix. n_source = 7 n_target = 4 @@ -81,13 +80,8 @@ Y_target = rng.randn(n_target, 2) M = ot.dist(X_source, Y_target) ############################################################################# -# # Call the "ASGD" method to find the transportation matrix in the semicontinous -# case -# --------------------------------------------- -# -# Define the method "ASGD", call ot.solve_semi_dual_entropic and plot the -# results. +# case. method = "ASGD" asgd_pi, log_asgd = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, @@ -96,23 +90,17 @@ print(log_asgd['alpha'], log_asgd['beta']) print(asgd_pi) ############################################################################# -# # Compare the results with the Sinkhorn algorithm -# --------------------------------------------- -# -# Call the Sinkhorn algorithm from POT sinkhorn_pi = ot.sinkhorn(a, b, M, reg) print(sinkhorn_pi) ############################################################################## -# PLOT TRANSPORTATION MATRIX -############################################################################## - -############################################################################## -# Plot SAG results -# ---------------- +# Plot Transportation Matrices +# ```````````````````````````` +# +# For SAG pl.figure(4, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, sag_pi, 'semi-dual : OT matrix SAG') @@ -120,8 +108,7 @@ pl.show() ############################################################################## -# Plot ASGD results -# ----------------- +# For ASGD pl.figure(4, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, asgd_pi, 'semi-dual : OT matrix ASGD') @@ -129,8 +116,7 @@ pl.show() ############################################################################## -# Plot Sinkhorn results -# --------------------- +# For Sinkhorn pl.figure(4, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn') @@ -138,17 +124,14 @@ pl.show() ############################################################################# -# COMPUTE TRANSPORTATION MATRIX FOR DUAL PROBLEM -############################################################################# -############################################################################# -# SEMICONTINOUS CASE: +# Compute the Transportation Matrix for the Dual Problem +# ------------------------------------------------------ # -# Sample one general measure a, one discrete measures b for the semicontinous -# case -# --------------------------------------------- +# Semi-continuous case +# ```````````````````` # -# Define one general measure a, one discrete measures b, the points where -# are defined the source and the target measures and finally the cost matrix c. +# Sample one general measure a, one discrete measures b for the semi-continuous +# case and compute the cost matrix c. n_source = 7 n_target = 4 @@ -169,10 +152,7 @@ M = ot.dist(X_source, Y_target) ############################################################################# # # Call the "SGD" dual method to find the transportation matrix in the -# semicontinous case -# --------------------------------------------- -# -# Call ot.solve_dual_entropic and plot the results. +# semi-continuous case sgd_dual_pi, log_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size, numItermax, @@ -183,7 +163,7 @@ print(sgd_dual_pi) ############################################################################# # # Compare the results with the Sinkhorn algorithm -# --------------------------------------------- +# ``````````````````````````````````````````````` # # Call the Sinkhorn algorithm from POT @@ -191,8 +171,10 @@ sinkhorn_pi = ot.sinkhorn(a, b, M, reg) print(sinkhorn_pi) ############################################################################## -# Plot SGD results -# ----------------- +# Plot Transportation Matrices +# ```````````````````````````` +# +# For SGD pl.figure(4, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, sgd_dual_pi, 'dual : OT matrix SGD') @@ -200,8 +182,7 @@ pl.show() ############################################################################## -# Plot Sinkhorn results -# --------------------- +# For Sinkhorn pl.figure(4, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn') |