summaryrefslogtreecommitdiff
path: root/examples/plot_stochastic.py
diff options
context:
space:
mode:
authorKilian Fatras <kilianfatras@dhcp-206-12-53-210.eduroam.wireless.ubc.ca>2018-06-18 17:56:28 -0700
committerKilian Fatras <kilianfatras@dhcp-206-12-53-210.eduroam.wireless.ubc.ca>2018-06-18 17:56:28 -0700
commit74cfe5ac77c3e964a85ef90c11d8ebffa16ddcfe (patch)
treecbff519da38279cc0e0cc4a4e8ab35d0169d16b5 /examples/plot_stochastic.py
parent055417ee06917ff8bac5d07b2d2a17d80e5da4b6 (diff)
add sgd
Diffstat (limited to 'examples/plot_stochastic.py')
-rw-r--r--examples/plot_stochastic.py95
1 files changed, 83 insertions, 12 deletions
diff --git a/examples/plot_stochastic.py b/examples/plot_stochastic.py
index 9071ddb..3fc1955 100644
--- a/examples/plot_stochastic.py
+++ b/examples/plot_stochastic.py
@@ -18,9 +18,9 @@ import ot
#############################################################################
-# COMPUTE TRANSPORTATION MATRIX
+# COMPUTE TRANSPORTATION MATRIX FOR SEMI-DUAL PROBLEM
#############################################################################
-
+print("------------SEMI-DUAL PROBLEM------------")
#############################################################################
# DISCRETE CASE
# Sample two discrete measures for the discrete case
@@ -48,12 +48,12 @@ M = ot.dist(X_source, Y_target)
# Call the "SAG" method to find the transportation matrix in the discrete case
# ---------------------------------------------
#
-# Define the method "SAG", call ot.transportation_matrix_entropic and plot the
+# Define the method "SAG", call ot.solve_semi_dual_entropic and plot the
# results.
method = "SAG"
-sag_pi = ot.stochastic.transportation_matrix_entropic(a, b, M, reg, method,
- numItermax, lr)
+sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
+ numItermax, lr)
print(sag_pi)
#############################################################################
@@ -68,8 +68,9 @@ print(sag_pi)
n_source = 7
n_target = 4
reg = 1
-numItermax = 500000
+numItermax = 100000
lr = 1
+log = True
a = ot.utils.unif(n_source)
b = ot.utils.unif(n_target)
@@ -85,12 +86,13 @@ M = ot.dist(X_source, Y_target)
# case
# ---------------------------------------------
#
-# Define the method "ASGD", call ot.transportation_matrix_entropic and plot the
+# Define the method "ASGD", call ot.solve_semi_dual_entropic and plot the
# results.
method = "ASGD"
-asgd_pi = ot.stochastic.transportation_matrix_entropic(a, b, M, reg, method,
- numItermax, lr)
+asgd_pi, log = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
+ numItermax, lr, log)
+print(log['alpha'], log['beta'])
print(asgd_pi)
#############################################################################
@@ -100,7 +102,7 @@ print(asgd_pi)
#
# Call the Sinkhorn algorithm from POT
-sinkhorn_pi = ot.sinkhorn(a, b, M, 1)
+sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
print(sinkhorn_pi)
@@ -113,7 +115,7 @@ print(sinkhorn_pi)
# ----------------
pl.figure(4, figsize=(5, 5))
-ot.plot.plot1D_mat(a, b, sag_pi, 'OT matrix SAG')
+ot.plot.plot1D_mat(a, b, sag_pi, 'semi-dual : OT matrix SAG')
pl.show()
@@ -122,7 +124,76 @@ pl.show()
# -----------------
pl.figure(4, figsize=(5, 5))
-ot.plot.plot1D_mat(a, b, asgd_pi, 'OT matrix ASGD')
+ot.plot.plot1D_mat(a, b, asgd_pi, 'semi-dual : OT matrix ASGD')
+pl.show()
+
+
+##############################################################################
+# Plot Sinkhorn results
+# ---------------------
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
+pl.show()
+
+#############################################################################
+# COMPUTE TRANSPORTATION MATRIX FOR DUAL PROBLEM
+#############################################################################
+print("------------DUAL PROBLEM------------")
+#############################################################################
+# SEMICONTINOUS CASE
+# Sample one general measure a, one discrete measures b for the semicontinous
+# case
+# ---------------------------------------------
+#
+# Define one general measure a, one discrete measures b, the points where
+# are defined the source and the target measures and finally the cost matrix c.
+
+n_source = 7
+n_target = 4
+reg = 1
+numItermax = 100000
+lr = 0.1
+batch_size = 3
+log = True
+
+a = ot.utils.unif(n_source)
+b = ot.utils.unif(n_target)
+
+rng = np.random.RandomState(0)
+X_source = rng.randn(n_source, 2)
+Y_target = rng.randn(n_target, 2)
+M = ot.dist(X_source, Y_target)
+
+#############################################################################
+#
+# Call the "SGD" dual method to find the transportation matrix in the semicontinous
+# case
+# ---------------------------------------------
+#
+# Call ot.solve_dual_entropic and plot the results.
+
+sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size,
+ numItermax, lr, log)
+print(log['alpha'], log['beta'])
+print(sgd_dual_pi)
+
+#############################################################################
+#
+# Compare the results with the Sinkhorn algorithm
+# ---------------------------------------------
+#
+# Call the Sinkhorn algorithm from POT
+
+sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
+print(sinkhorn_pi)
+
+##############################################################################
+# Plot SGD results
+# -----------------
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, sgd_dual_pi, 'dual : OT matrix SGD')
pl.show()