From 6777ffd5c8457faac4467e58ba9edbcf2f86961b Mon Sep 17 00:00:00 2001 From: Kilian Fatras Date: Thu, 21 Jun 2018 17:18:26 -0700 Subject: gave better step size ASGD & SAG --- examples/plot_stochastic.py | 10 ++++------ ot/stochastic.py | 37 ++++++++++++++++--------------------- test/test_stochastic.py | 7 ++----- 3 files changed, 22 insertions(+), 32 deletions(-) diff --git a/examples/plot_stochastic.py b/examples/plot_stochastic.py index 09b95d0..6274b4c 100644 --- a/examples/plot_stochastic.py +++ b/examples/plot_stochastic.py @@ -32,8 +32,7 @@ print("------------SEMI-DUAL PROBLEM------------") n_source = 7 n_target = 4 reg = 1 -numItermax = 10000 -lr = 0.1 +numItermax = 1000 a = ot.utils.unif(n_source) b = ot.utils.unif(n_target) @@ -53,7 +52,7 @@ M = ot.dist(X_source, Y_target) method = "SAG" sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, - numItermax, lr) + numItermax) print(sag_pi) ############################################################################# @@ -68,8 +67,7 @@ print(sag_pi) n_source = 7 n_target = 4 reg = 1 -numItermax = 100000 -lr = 1 +numItermax = 1000 log = True a = ot.utils.unif(n_source) @@ -91,7 +89,7 @@ M = ot.dist(X_source, Y_target) method = "ASGD" asgd_pi, log = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, - numItermax, lr, log) + numItermax, log) print(log['alpha'], log['beta']) print(asgd_pi) diff --git a/ot/stochastic.py b/ot/stochastic.py index 57b96b7..ab88cd0 100644 --- a/ot/stochastic.py +++ b/ot/stochastic.py @@ -56,7 +56,6 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i): >>> n_target = 4 >>> reg = 1 >>> numItermax = 300000 - >>> lr = 1 >>> a = ot.utils.unif(n_source) >>> b = ot.utils.unif(n_target) >>> rng = np.random.RandomState(0) @@ -65,8 +64,7 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i): >>> M = ot.dist(X_source, Y_target) >>> method = "ASGD" >>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg, - method, numItermax, - lr) + method, numItermax) >>> print(asgd_pi) References @@ -85,7 +83,7 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i): return b - khi -def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=0.1): +def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None): ''' Compute the SAG algorithm to solve the regularized discrete measures optimal transport max problem @@ -134,17 +132,15 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=0.1): >>> n_target = 4 >>> reg = 1 >>> numItermax = 300000 - >>> lr = 1 >>> a = ot.utils.unif(n_source) >>> b = ot.utils.unif(n_target) >>> rng = np.random.RandomState(0) >>> X_source = rng.randn(n_source, 2) >>> Y_target = rng.randn(n_target, 2) >>> M = ot.dist(X_source, Y_target) - >>> method = "SAG" - >>> sag_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg, - method, numItermax, - lr) + >>> method = "ASGD" + >>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg, + method, numItermax) >>> print(asgd_pi) References @@ -156,6 +152,8 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=0.1): arXiv preprint arxiv:1605.08527. ''' + if lr is None: + lr = 1. / max(a) n_source = np.shape(M)[0] n_target = np.shape(M)[1] cur_beta = np.zeros(n_target) @@ -171,7 +169,7 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=0.1): return cur_beta -def averaged_sgd_entropic_transport(b, M, reg, numItermax=300000, lr=1): +def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None): ''' Compute the ASGD algorithm to solve the regularized semi contibous measures optimal transport max problem @@ -219,7 +217,6 @@ def averaged_sgd_entropic_transport(b, M, reg, numItermax=300000, lr=1): >>> n_target = 4 >>> reg = 1 >>> numItermax = 300000 - >>> lr = 1 >>> a = ot.utils.unif(n_source) >>> b = ot.utils.unif(n_target) >>> rng = np.random.RandomState(0) @@ -228,8 +225,7 @@ def averaged_sgd_entropic_transport(b, M, reg, numItermax=300000, lr=1): >>> M = ot.dist(X_source, Y_target) >>> method = "ASGD" >>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg, - method, numItermax, - lr) + method, numItermax) >>> print(asgd_pi) References @@ -241,6 +237,8 @@ def averaged_sgd_entropic_transport(b, M, reg, numItermax=300000, lr=1): arXiv preprint arxiv:1605.08527. ''' + if lr is None: + lr = 1. / max(a) n_source = np.shape(M)[0] n_target = np.shape(M)[1] cur_beta = np.zeros(n_target) @@ -296,7 +294,6 @@ def c_transform_entropic(b, M, reg, beta): >>> n_target = 4 >>> reg = 1 >>> numItermax = 300000 - >>> lr = 1 >>> a = ot.utils.unif(n_source) >>> b = ot.utils.unif(n_target) >>> rng = np.random.RandomState(0) @@ -305,8 +302,7 @@ def c_transform_entropic(b, M, reg, beta): >>> M = ot.dist(X_source, Y_target) >>> method = "ASGD" >>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg, - method, numItermax, - lr) + method, numItermax) >>> print(asgd_pi) References @@ -328,7 +324,7 @@ def c_transform_entropic(b, M, reg, beta): return alpha -def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1, +def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None, log=False): ''' Compute the transportation matrix to solve the regularized discrete @@ -388,7 +384,6 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1, >>> n_target = 4 >>> reg = 1 >>> numItermax = 300000 - >>> lr = 1 >>> a = ot.utils.unif(n_source) >>> b = ot.utils.unif(n_target) >>> rng = np.random.RandomState(0) @@ -397,8 +392,7 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1, >>> M = ot.dist(X_source, Y_target) >>> method = "ASGD" >>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg, - method, numItermax, - lr) + method, numItermax) >>> print(asgd_pi) References @@ -409,10 +403,11 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1, Advances in Neural Information Processing Systems (2016), arXiv preprint arxiv:1605.08527. ''' + if method.lower() == "sag": opt_beta = sag_entropic_transport(a, b, M, reg, numItermax, lr) elif method.lower() == "asgd": - opt_beta = averaged_sgd_entropic_transport(b, M, reg, numItermax, lr) + opt_beta = averaged_sgd_entropic_transport(a, b, M, reg, numItermax, lr) else: print("Please, select your method between SAG and ASGD") return None diff --git a/test/test_stochastic.py b/test/test_stochastic.py index 3cb51ff..f315c88 100644 --- a/test/test_stochastic.py +++ b/test/test_stochastic.py @@ -63,7 +63,6 @@ def test_stochastic_asgd(): n = 15 reg = 1 numItermax = 300000 - lr = 1 rng = np.random.RandomState(0) x = rng.randn(n, 2) @@ -72,8 +71,7 @@ def test_stochastic_asgd(): M = ot.dist(x, x) G = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd", - numItermax=numItermax, - lr=lr) + numItermax=numItermax) # check constratints np.testing.assert_allclose( @@ -95,7 +93,6 @@ def test_sag_asgd_sinkhorn(): n = 15 reg = 1 nb_iter = 300000 - lr = 1 rng = np.random.RandomState(0) x = rng.randn(n, 2) @@ -104,7 +101,7 @@ def test_sag_asgd_sinkhorn(): M = ot.dist(x, x) G_asgd = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd", - numItermax=nb_iter, lr=lr) + numItermax=nb_iter) G_sag = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "sag", numItermax=nb_iter) G_sinkhorn = ot.sinkhorn(u, u, M, reg) -- cgit v1.2.3