diff options
author | Kilian Fatras <kilianfatras@dhcp-206-12-53-92.eduroam.wireless.ubc.ca> | 2018-06-21 17:18:26 -0700 |
---|---|---|
committer | Kilian Fatras <kilianfatras@dhcp-206-12-53-92.eduroam.wireless.ubc.ca> | 2018-06-21 17:18:26 -0700 |
commit | 6777ffd5c8457faac4467e58ba9edbcf2f86961b (patch) | |
tree | 286c1fdb3902031d47cc5e1bdefab97ab9c9da22 /ot/stochastic.py | |
parent | 7073e417bed151976c62fe20d1ba69abd30e7758 (diff) |
gave better step size ASGD & SAG
Diffstat (limited to 'ot/stochastic.py')
-rw-r--r-- | ot/stochastic.py | 37 |
1 files changed, 16 insertions, 21 deletions
diff --git a/ot/stochastic.py b/ot/stochastic.py index 57b96b7..ab88cd0 100644 --- a/ot/stochastic.py +++ b/ot/stochastic.py @@ -56,7 +56,6 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i): >>> n_target = 4 >>> reg = 1 >>> numItermax = 300000 - >>> lr = 1 >>> a = ot.utils.unif(n_source) >>> b = ot.utils.unif(n_target) >>> rng = np.random.RandomState(0) @@ -65,8 +64,7 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i): >>> M = ot.dist(X_source, Y_target) >>> method = "ASGD" >>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg, - method, numItermax, - lr) + method, numItermax) >>> print(asgd_pi) References @@ -85,7 +83,7 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i): return b - khi -def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=0.1): +def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None): ''' Compute the SAG algorithm to solve the regularized discrete measures optimal transport max problem @@ -134,17 +132,15 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=0.1): >>> n_target = 4 >>> reg = 1 >>> numItermax = 300000 - >>> lr = 1 >>> a = ot.utils.unif(n_source) >>> b = ot.utils.unif(n_target) >>> rng = np.random.RandomState(0) >>> X_source = rng.randn(n_source, 2) >>> Y_target = rng.randn(n_target, 2) >>> M = ot.dist(X_source, Y_target) - >>> method = "SAG" - >>> sag_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg, - method, numItermax, - lr) + >>> method = "ASGD" + >>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg, + method, numItermax) >>> print(asgd_pi) References @@ -156,6 +152,8 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=0.1): arXiv preprint arxiv:1605.08527. ''' + if lr is None: + lr = 1. / max(a) n_source = np.shape(M)[0] n_target = np.shape(M)[1] cur_beta = np.zeros(n_target) @@ -171,7 +169,7 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=0.1): return cur_beta -def averaged_sgd_entropic_transport(b, M, reg, numItermax=300000, lr=1): +def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None): ''' Compute the ASGD algorithm to solve the regularized semi contibous measures optimal transport max problem @@ -219,7 +217,6 @@ def averaged_sgd_entropic_transport(b, M, reg, numItermax=300000, lr=1): >>> n_target = 4 >>> reg = 1 >>> numItermax = 300000 - >>> lr = 1 >>> a = ot.utils.unif(n_source) >>> b = ot.utils.unif(n_target) >>> rng = np.random.RandomState(0) @@ -228,8 +225,7 @@ def averaged_sgd_entropic_transport(b, M, reg, numItermax=300000, lr=1): >>> M = ot.dist(X_source, Y_target) >>> method = "ASGD" >>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg, - method, numItermax, - lr) + method, numItermax) >>> print(asgd_pi) References @@ -241,6 +237,8 @@ def averaged_sgd_entropic_transport(b, M, reg, numItermax=300000, lr=1): arXiv preprint arxiv:1605.08527. ''' + if lr is None: + lr = 1. / max(a) n_source = np.shape(M)[0] n_target = np.shape(M)[1] cur_beta = np.zeros(n_target) @@ -296,7 +294,6 @@ def c_transform_entropic(b, M, reg, beta): >>> n_target = 4 >>> reg = 1 >>> numItermax = 300000 - >>> lr = 1 >>> a = ot.utils.unif(n_source) >>> b = ot.utils.unif(n_target) >>> rng = np.random.RandomState(0) @@ -305,8 +302,7 @@ def c_transform_entropic(b, M, reg, beta): >>> M = ot.dist(X_source, Y_target) >>> method = "ASGD" >>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg, - method, numItermax, - lr) + method, numItermax) >>> print(asgd_pi) References @@ -328,7 +324,7 @@ def c_transform_entropic(b, M, reg, beta): return alpha -def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1, +def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None, log=False): ''' Compute the transportation matrix to solve the regularized discrete @@ -388,7 +384,6 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1, >>> n_target = 4 >>> reg = 1 >>> numItermax = 300000 - >>> lr = 1 >>> a = ot.utils.unif(n_source) >>> b = ot.utils.unif(n_target) >>> rng = np.random.RandomState(0) @@ -397,8 +392,7 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1, >>> M = ot.dist(X_source, Y_target) >>> method = "ASGD" >>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg, - method, numItermax, - lr) + method, numItermax) >>> print(asgd_pi) References @@ -409,10 +403,11 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1, Advances in Neural Information Processing Systems (2016), arXiv preprint arxiv:1605.08527. ''' + if method.lower() == "sag": opt_beta = sag_entropic_transport(a, b, M, reg, numItermax, lr) elif method.lower() == "asgd": - opt_beta = averaged_sgd_entropic_transport(b, M, reg, numItermax, lr) + opt_beta = averaged_sgd_entropic_transport(a, b, M, reg, numItermax, lr) else: print("Please, select your method between SAG and ASGD") return None |