summaryrefslogtreecommitdiff
path: root/ot/stochastic.py
diff options
context:
space:
mode:
authorKilian Fatras <kilianfatras@dhcp-206-12-53-92.eduroam.wireless.ubc.ca>2018-06-21 17:18:26 -0700
committerKilian Fatras <kilianfatras@dhcp-206-12-53-92.eduroam.wireless.ubc.ca>2018-06-21 17:18:26 -0700
commit6777ffd5c8457faac4467e58ba9edbcf2f86961b (patch)
tree286c1fdb3902031d47cc5e1bdefab97ab9c9da22 /ot/stochastic.py
parent7073e417bed151976c62fe20d1ba69abd30e7758 (diff)
gave better step size ASGD & SAG
Diffstat (limited to 'ot/stochastic.py')
-rw-r--r--ot/stochastic.py37
1 files changed, 16 insertions, 21 deletions
diff --git a/ot/stochastic.py b/ot/stochastic.py
index 57b96b7..ab88cd0 100644
--- a/ot/stochastic.py
+++ b/ot/stochastic.py
@@ -56,7 +56,6 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
>>> n_target = 4
>>> reg = 1
>>> numItermax = 300000
- >>> lr = 1
>>> a = ot.utils.unif(n_source)
>>> b = ot.utils.unif(n_target)
>>> rng = np.random.RandomState(0)
@@ -65,8 +64,7 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
>>> M = ot.dist(X_source, Y_target)
>>> method = "ASGD"
>>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg,
- method, numItermax,
- lr)
+ method, numItermax)
>>> print(asgd_pi)
References
@@ -85,7 +83,7 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
return b - khi
-def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=0.1):
+def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
'''
Compute the SAG algorithm to solve the regularized discrete measures
optimal transport max problem
@@ -134,17 +132,15 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=0.1):
>>> n_target = 4
>>> reg = 1
>>> numItermax = 300000
- >>> lr = 1
>>> a = ot.utils.unif(n_source)
>>> b = ot.utils.unif(n_target)
>>> rng = np.random.RandomState(0)
>>> X_source = rng.randn(n_source, 2)
>>> Y_target = rng.randn(n_target, 2)
>>> M = ot.dist(X_source, Y_target)
- >>> method = "SAG"
- >>> sag_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg,
- method, numItermax,
- lr)
+ >>> method = "ASGD"
+ >>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg,
+ method, numItermax)
>>> print(asgd_pi)
References
@@ -156,6 +152,8 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=0.1):
arXiv preprint arxiv:1605.08527.
'''
+ if lr is None:
+ lr = 1. / max(a)
n_source = np.shape(M)[0]
n_target = np.shape(M)[1]
cur_beta = np.zeros(n_target)
@@ -171,7 +169,7 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=0.1):
return cur_beta
-def averaged_sgd_entropic_transport(b, M, reg, numItermax=300000, lr=1):
+def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
'''
Compute the ASGD algorithm to solve the regularized semi contibous measures
optimal transport max problem
@@ -219,7 +217,6 @@ def averaged_sgd_entropic_transport(b, M, reg, numItermax=300000, lr=1):
>>> n_target = 4
>>> reg = 1
>>> numItermax = 300000
- >>> lr = 1
>>> a = ot.utils.unif(n_source)
>>> b = ot.utils.unif(n_target)
>>> rng = np.random.RandomState(0)
@@ -228,8 +225,7 @@ def averaged_sgd_entropic_transport(b, M, reg, numItermax=300000, lr=1):
>>> M = ot.dist(X_source, Y_target)
>>> method = "ASGD"
>>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg,
- method, numItermax,
- lr)
+ method, numItermax)
>>> print(asgd_pi)
References
@@ -241,6 +237,8 @@ def averaged_sgd_entropic_transport(b, M, reg, numItermax=300000, lr=1):
arXiv preprint arxiv:1605.08527.
'''
+ if lr is None:
+ lr = 1. / max(a)
n_source = np.shape(M)[0]
n_target = np.shape(M)[1]
cur_beta = np.zeros(n_target)
@@ -296,7 +294,6 @@ def c_transform_entropic(b, M, reg, beta):
>>> n_target = 4
>>> reg = 1
>>> numItermax = 300000
- >>> lr = 1
>>> a = ot.utils.unif(n_source)
>>> b = ot.utils.unif(n_target)
>>> rng = np.random.RandomState(0)
@@ -305,8 +302,7 @@ def c_transform_entropic(b, M, reg, beta):
>>> M = ot.dist(X_source, Y_target)
>>> method = "ASGD"
>>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg,
- method, numItermax,
- lr)
+ method, numItermax)
>>> print(asgd_pi)
References
@@ -328,7 +324,7 @@ def c_transform_entropic(b, M, reg, beta):
return alpha
-def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1,
+def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
log=False):
'''
Compute the transportation matrix to solve the regularized discrete
@@ -388,7 +384,6 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1,
>>> n_target = 4
>>> reg = 1
>>> numItermax = 300000
- >>> lr = 1
>>> a = ot.utils.unif(n_source)
>>> b = ot.utils.unif(n_target)
>>> rng = np.random.RandomState(0)
@@ -397,8 +392,7 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1,
>>> M = ot.dist(X_source, Y_target)
>>> method = "ASGD"
>>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg,
- method, numItermax,
- lr)
+ method, numItermax)
>>> print(asgd_pi)
References
@@ -409,10 +403,11 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1,
Advances in Neural Information Processing Systems (2016),
arXiv preprint arxiv:1605.08527.
'''
+
if method.lower() == "sag":
opt_beta = sag_entropic_transport(a, b, M, reg, numItermax, lr)
elif method.lower() == "asgd":
- opt_beta = averaged_sgd_entropic_transport(b, M, reg, numItermax, lr)
+ opt_beta = averaged_sgd_entropic_transport(a, b, M, reg, numItermax, lr)
else:
print("Please, select your method between SAG and ASGD")
return None