diff options
author | Kilian Fatras <kilianfatras@dhcp-206-12-53-210.eduroam.wireless.ubc.ca> | 2018-06-18 17:56:28 -0700 |
---|---|---|
committer | Kilian Fatras <kilianfatras@dhcp-206-12-53-210.eduroam.wireless.ubc.ca> | 2018-06-18 17:56:28 -0700 |
commit | 74cfe5ac77c3e964a85ef90c11d8ebffa16ddcfe (patch) | |
tree | cbff519da38279cc0e0cc4a4e8ab35d0169d16b5 /examples | |
parent | 055417ee06917ff8bac5d07b2d2a17d80e5da4b6 (diff) |
add sgd
Diffstat (limited to 'examples')
-rw-r--r-- | examples/plot_stochastic.py | 95 |
1 files changed, 83 insertions, 12 deletions
diff --git a/examples/plot_stochastic.py b/examples/plot_stochastic.py index 9071ddb..3fc1955 100644 --- a/examples/plot_stochastic.py +++ b/examples/plot_stochastic.py @@ -18,9 +18,9 @@ import ot ############################################################################# -# COMPUTE TRANSPORTATION MATRIX +# COMPUTE TRANSPORTATION MATRIX FOR SEMI-DUAL PROBLEM ############################################################################# - +print("------------SEMI-DUAL PROBLEM------------") ############################################################################# # DISCRETE CASE # Sample two discrete measures for the discrete case @@ -48,12 +48,12 @@ M = ot.dist(X_source, Y_target) # Call the "SAG" method to find the transportation matrix in the discrete case # --------------------------------------------- # -# Define the method "SAG", call ot.transportation_matrix_entropic and plot the +# Define the method "SAG", call ot.solve_semi_dual_entropic and plot the # results. method = "SAG" -sag_pi = ot.stochastic.transportation_matrix_entropic(a, b, M, reg, method, - numItermax, lr) +sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, + numItermax, lr) print(sag_pi) ############################################################################# @@ -68,8 +68,9 @@ print(sag_pi) n_source = 7 n_target = 4 reg = 1 -numItermax = 500000 +numItermax = 100000 lr = 1 +log = True a = ot.utils.unif(n_source) b = ot.utils.unif(n_target) @@ -85,12 +86,13 @@ M = ot.dist(X_source, Y_target) # case # --------------------------------------------- # -# Define the method "ASGD", call ot.transportation_matrix_entropic and plot the +# Define the method "ASGD", call ot.solve_semi_dual_entropic and plot the # results. method = "ASGD" -asgd_pi = ot.stochastic.transportation_matrix_entropic(a, b, M, reg, method, - numItermax, lr) +asgd_pi, log = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, + numItermax, lr, log) +print(log['alpha'], log['beta']) print(asgd_pi) ############################################################################# @@ -100,7 +102,7 @@ print(asgd_pi) # # Call the Sinkhorn algorithm from POT -sinkhorn_pi = ot.sinkhorn(a, b, M, 1) +sinkhorn_pi = ot.sinkhorn(a, b, M, reg) print(sinkhorn_pi) @@ -113,7 +115,7 @@ print(sinkhorn_pi) # ---------------- pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, sag_pi, 'OT matrix SAG') +ot.plot.plot1D_mat(a, b, sag_pi, 'semi-dual : OT matrix SAG') pl.show() @@ -122,7 +124,76 @@ pl.show() # ----------------- pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, asgd_pi, 'OT matrix ASGD') +ot.plot.plot1D_mat(a, b, asgd_pi, 'semi-dual : OT matrix ASGD') +pl.show() + + +############################################################################## +# Plot Sinkhorn results +# --------------------- + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn') +pl.show() + +############################################################################# +# COMPUTE TRANSPORTATION MATRIX FOR DUAL PROBLEM +############################################################################# +print("------------DUAL PROBLEM------------") +############################################################################# +# SEMICONTINOUS CASE +# Sample one general measure a, one discrete measures b for the semicontinous +# case +# --------------------------------------------- +# +# Define one general measure a, one discrete measures b, the points where +# are defined the source and the target measures and finally the cost matrix c. + +n_source = 7 +n_target = 4 +reg = 1 +numItermax = 100000 +lr = 0.1 +batch_size = 3 +log = True + +a = ot.utils.unif(n_source) +b = ot.utils.unif(n_target) + +rng = np.random.RandomState(0) +X_source = rng.randn(n_source, 2) +Y_target = rng.randn(n_target, 2) +M = ot.dist(X_source, Y_target) + +############################################################################# +# +# Call the "SGD" dual method to find the transportation matrix in the semicontinous +# case +# --------------------------------------------- +# +# Call ot.solve_dual_entropic and plot the results. + +sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size, + numItermax, lr, log) +print(log['alpha'], log['beta']) +print(sgd_dual_pi) + +############################################################################# +# +# Compare the results with the Sinkhorn algorithm +# --------------------------------------------- +# +# Call the Sinkhorn algorithm from POT + +sinkhorn_pi = ot.sinkhorn(a, b, M, reg) +print(sinkhorn_pi) + +############################################################################## +# Plot SGD results +# ----------------- + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, sgd_dual_pi, 'dual : OT matrix SGD') pl.show() |