summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorKilian Fatras <kilianfatras@dhcp-206-12-53-99.eduroam.wireless.ubc.ca>2018-06-15 18:53:54 -0700
committerKilian Fatras <kilianfatras@dhcp-206-12-53-99.eduroam.wireless.ubc.ca>2018-06-15 18:53:54 -0700
commitc8eda449b2c6b39e9d57d1b5b2c39e43f2925892 (patch)
tree19e241788b920c8c5c181ca6562553a2e3bd7468 /examples
parent90efa5a8b189214d1aeb81920b2bb04ce0c261ca (diff)
add problems solved in doc
Diffstat (limited to 'examples')
-rw-r--r--examples/plot_stochastic.py135
1 files changed, 135 insertions, 0 deletions
diff --git a/examples/plot_stochastic.py b/examples/plot_stochastic.py
new file mode 100644
index 0000000..9071ddb
--- /dev/null
+++ b/examples/plot_stochastic.py
@@ -0,0 +1,135 @@
+"""
+==========================
+Stochastic examples
+==========================
+
+This example is designed to show how to use the stochatic optimization
+algorithms for descrete and semicontinous measures from the POT library.
+
+"""
+
+# Author: Kilian Fatras <kilian.fatras@gmail.com>
+#
+# License: MIT License
+
+import matplotlib.pylab as pl
+import numpy as np
+import ot
+
+
+#############################################################################
+# COMPUTE TRANSPORTATION MATRIX
+#############################################################################
+
+#############################################################################
+# DISCRETE CASE
+# Sample two discrete measures for the discrete case
+# ---------------------------------------------
+#
+# Define 2 discrete measures a and b, the points where are defined the source
+# and the target measures and finally the cost matrix c.
+
+n_source = 7
+n_target = 4
+reg = 1
+numItermax = 10000
+lr = 0.1
+
+a = ot.utils.unif(n_source)
+b = ot.utils.unif(n_target)
+
+rng = np.random.RandomState(0)
+X_source = rng.randn(n_source, 2)
+Y_target = rng.randn(n_target, 2)
+M = ot.dist(X_source, Y_target)
+
+#############################################################################
+#
+# Call the "SAG" method to find the transportation matrix in the discrete case
+# ---------------------------------------------
+#
+# Define the method "SAG", call ot.transportation_matrix_entropic and plot the
+# results.
+
+method = "SAG"
+sag_pi = ot.stochastic.transportation_matrix_entropic(a, b, M, reg, method,
+ numItermax, lr)
+print(sag_pi)
+
+#############################################################################
+# SEMICONTINOUS CASE
+# Sample one general measure a, one discrete measures b for the semicontinous
+# case
+# ---------------------------------------------
+#
+# Define one general measure a, one discrete measures b, the points where
+# are defined the source and the target measures and finally the cost matrix c.
+
+n_source = 7
+n_target = 4
+reg = 1
+numItermax = 500000
+lr = 1
+
+a = ot.utils.unif(n_source)
+b = ot.utils.unif(n_target)
+
+rng = np.random.RandomState(0)
+X_source = rng.randn(n_source, 2)
+Y_target = rng.randn(n_target, 2)
+M = ot.dist(X_source, Y_target)
+
+#############################################################################
+#
+# Call the "ASGD" method to find the transportation matrix in the semicontinous
+# case
+# ---------------------------------------------
+#
+# Define the method "ASGD", call ot.transportation_matrix_entropic and plot the
+# results.
+
+method = "ASGD"
+asgd_pi = ot.stochastic.transportation_matrix_entropic(a, b, M, reg, method,
+ numItermax, lr)
+print(asgd_pi)
+
+#############################################################################
+#
+# Compare the results with the Sinkhorn algorithm
+# ---------------------------------------------
+#
+# Call the Sinkhorn algorithm from POT
+
+sinkhorn_pi = ot.sinkhorn(a, b, M, 1)
+print(sinkhorn_pi)
+
+
+##############################################################################
+# PLOT TRANSPORTATION MATRIX
+##############################################################################
+
+##############################################################################
+# Plot SAG results
+# ----------------
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, sag_pi, 'OT matrix SAG')
+pl.show()
+
+
+##############################################################################
+# Plot ASGD results
+# -----------------
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, asgd_pi, 'OT matrix ASGD')
+pl.show()
+
+
+##############################################################################
+# Plot Sinkhorn results
+# ---------------------
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
+pl.show()