diff options
author | Kilian Fatras <kilianfatras@dhcp-206-12-53-99.eduroam.wireless.ubc.ca> | 2018-06-15 18:53:54 -0700 |
---|---|---|
committer | Kilian Fatras <kilianfatras@dhcp-206-12-53-99.eduroam.wireless.ubc.ca> | 2018-06-15 18:53:54 -0700 |
commit | c8eda449b2c6b39e9d57d1b5b2c39e43f2925892 (patch) | |
tree | 19e241788b920c8c5c181ca6562553a2e3bd7468 /examples/plot_stochastic.py | |
parent | 90efa5a8b189214d1aeb81920b2bb04ce0c261ca (diff) |
add problems solved in doc
Diffstat (limited to 'examples/plot_stochastic.py')
-rw-r--r-- | examples/plot_stochastic.py | 135 |
1 files changed, 135 insertions, 0 deletions
diff --git a/examples/plot_stochastic.py b/examples/plot_stochastic.py new file mode 100644 index 0000000..9071ddb --- /dev/null +++ b/examples/plot_stochastic.py @@ -0,0 +1,135 @@ +""" +========================== +Stochastic examples +========================== + +This example is designed to show how to use the stochatic optimization +algorithms for descrete and semicontinous measures from the POT library. + +""" + +# Author: Kilian Fatras <kilian.fatras@gmail.com> +# +# License: MIT License + +import matplotlib.pylab as pl +import numpy as np +import ot + + +############################################################################# +# COMPUTE TRANSPORTATION MATRIX +############################################################################# + +############################################################################# +# DISCRETE CASE +# Sample two discrete measures for the discrete case +# --------------------------------------------- +# +# Define 2 discrete measures a and b, the points where are defined the source +# and the target measures and finally the cost matrix c. + +n_source = 7 +n_target = 4 +reg = 1 +numItermax = 10000 +lr = 0.1 + +a = ot.utils.unif(n_source) +b = ot.utils.unif(n_target) + +rng = np.random.RandomState(0) +X_source = rng.randn(n_source, 2) +Y_target = rng.randn(n_target, 2) +M = ot.dist(X_source, Y_target) + +############################################################################# +# +# Call the "SAG" method to find the transportation matrix in the discrete case +# --------------------------------------------- +# +# Define the method "SAG", call ot.transportation_matrix_entropic and plot the +# results. + +method = "SAG" +sag_pi = ot.stochastic.transportation_matrix_entropic(a, b, M, reg, method, + numItermax, lr) +print(sag_pi) + +############################################################################# +# SEMICONTINOUS CASE +# Sample one general measure a, one discrete measures b for the semicontinous +# case +# --------------------------------------------- +# +# Define one general measure a, one discrete measures b, the points where +# are defined the source and the target measures and finally the cost matrix c. + +n_source = 7 +n_target = 4 +reg = 1 +numItermax = 500000 +lr = 1 + +a = ot.utils.unif(n_source) +b = ot.utils.unif(n_target) + +rng = np.random.RandomState(0) +X_source = rng.randn(n_source, 2) +Y_target = rng.randn(n_target, 2) +M = ot.dist(X_source, Y_target) + +############################################################################# +# +# Call the "ASGD" method to find the transportation matrix in the semicontinous +# case +# --------------------------------------------- +# +# Define the method "ASGD", call ot.transportation_matrix_entropic and plot the +# results. + +method = "ASGD" +asgd_pi = ot.stochastic.transportation_matrix_entropic(a, b, M, reg, method, + numItermax, lr) +print(asgd_pi) + +############################################################################# +# +# Compare the results with the Sinkhorn algorithm +# --------------------------------------------- +# +# Call the Sinkhorn algorithm from POT + +sinkhorn_pi = ot.sinkhorn(a, b, M, 1) +print(sinkhorn_pi) + + +############################################################################## +# PLOT TRANSPORTATION MATRIX +############################################################################## + +############################################################################## +# Plot SAG results +# ---------------- + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, sag_pi, 'OT matrix SAG') +pl.show() + + +############################################################################## +# Plot ASGD results +# ----------------- + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, asgd_pi, 'OT matrix ASGD') +pl.show() + + +############################################################################## +# Plot Sinkhorn results +# --------------------- + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn') +pl.show() |