summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKilian Fatras <kilianfatras@dhcp-206-12-53-92.eduroam.wireless.ubc.ca>2018-06-21 17:18:26 -0700
committerKilian Fatras <kilianfatras@dhcp-206-12-53-92.eduroam.wireless.ubc.ca>2018-06-21 17:18:26 -0700
commit6777ffd5c8457faac4467e58ba9edbcf2f86961b (patch)
tree286c1fdb3902031d47cc5e1bdefab97ab9c9da22
parent7073e417bed151976c62fe20d1ba69abd30e7758 (diff)
gave better step size ASGD & SAG
-rw-r--r--examples/plot_stochastic.py10
-rw-r--r--ot/stochastic.py37
-rw-r--r--test/test_stochastic.py7
3 files changed, 22 insertions, 32 deletions
diff --git a/examples/plot_stochastic.py b/examples/plot_stochastic.py
index 09b95d0..6274b4c 100644
--- a/examples/plot_stochastic.py
+++ b/examples/plot_stochastic.py
@@ -32,8 +32,7 @@ print("------------SEMI-DUAL PROBLEM------------")
n_source = 7
n_target = 4
reg = 1
-numItermax = 10000
-lr = 0.1
+numItermax = 1000
a = ot.utils.unif(n_source)
b = ot.utils.unif(n_target)
@@ -53,7 +52,7 @@ M = ot.dist(X_source, Y_target)
method = "SAG"
sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
- numItermax, lr)
+ numItermax)
print(sag_pi)
#############################################################################
@@ -68,8 +67,7 @@ print(sag_pi)
n_source = 7
n_target = 4
reg = 1
-numItermax = 100000
-lr = 1
+numItermax = 1000
log = True
a = ot.utils.unif(n_source)
@@ -91,7 +89,7 @@ M = ot.dist(X_source, Y_target)
method = "ASGD"
asgd_pi, log = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
- numItermax, lr, log)
+ numItermax, log)
print(log['alpha'], log['beta'])
print(asgd_pi)
diff --git a/ot/stochastic.py b/ot/stochastic.py
index 57b96b7..ab88cd0 100644
--- a/ot/stochastic.py
+++ b/ot/stochastic.py
@@ -56,7 +56,6 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
>>> n_target = 4
>>> reg = 1
>>> numItermax = 300000
- >>> lr = 1
>>> a = ot.utils.unif(n_source)
>>> b = ot.utils.unif(n_target)
>>> rng = np.random.RandomState(0)
@@ -65,8 +64,7 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
>>> M = ot.dist(X_source, Y_target)
>>> method = "ASGD"
>>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg,
- method, numItermax,
- lr)
+ method, numItermax)
>>> print(asgd_pi)
References
@@ -85,7 +83,7 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
return b - khi
-def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=0.1):
+def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
'''
Compute the SAG algorithm to solve the regularized discrete measures
optimal transport max problem
@@ -134,17 +132,15 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=0.1):
>>> n_target = 4
>>> reg = 1
>>> numItermax = 300000
- >>> lr = 1
>>> a = ot.utils.unif(n_source)
>>> b = ot.utils.unif(n_target)
>>> rng = np.random.RandomState(0)
>>> X_source = rng.randn(n_source, 2)
>>> Y_target = rng.randn(n_target, 2)
>>> M = ot.dist(X_source, Y_target)
- >>> method = "SAG"
- >>> sag_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg,
- method, numItermax,
- lr)
+ >>> method = "ASGD"
+ >>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg,
+ method, numItermax)
>>> print(asgd_pi)
References
@@ -156,6 +152,8 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=0.1):
arXiv preprint arxiv:1605.08527.
'''
+ if lr is None:
+ lr = 1. / max(a)
n_source = np.shape(M)[0]
n_target = np.shape(M)[1]
cur_beta = np.zeros(n_target)
@@ -171,7 +169,7 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=0.1):
return cur_beta
-def averaged_sgd_entropic_transport(b, M, reg, numItermax=300000, lr=1):
+def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
'''
Compute the ASGD algorithm to solve the regularized semi contibous measures
optimal transport max problem
@@ -219,7 +217,6 @@ def averaged_sgd_entropic_transport(b, M, reg, numItermax=300000, lr=1):
>>> n_target = 4
>>> reg = 1
>>> numItermax = 300000
- >>> lr = 1
>>> a = ot.utils.unif(n_source)
>>> b = ot.utils.unif(n_target)
>>> rng = np.random.RandomState(0)
@@ -228,8 +225,7 @@ def averaged_sgd_entropic_transport(b, M, reg, numItermax=300000, lr=1):
>>> M = ot.dist(X_source, Y_target)
>>> method = "ASGD"
>>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg,
- method, numItermax,
- lr)
+ method, numItermax)
>>> print(asgd_pi)
References
@@ -241,6 +237,8 @@ def averaged_sgd_entropic_transport(b, M, reg, numItermax=300000, lr=1):
arXiv preprint arxiv:1605.08527.
'''
+ if lr is None:
+ lr = 1. / max(a)
n_source = np.shape(M)[0]
n_target = np.shape(M)[1]
cur_beta = np.zeros(n_target)
@@ -296,7 +294,6 @@ def c_transform_entropic(b, M, reg, beta):
>>> n_target = 4
>>> reg = 1
>>> numItermax = 300000
- >>> lr = 1
>>> a = ot.utils.unif(n_source)
>>> b = ot.utils.unif(n_target)
>>> rng = np.random.RandomState(0)
@@ -305,8 +302,7 @@ def c_transform_entropic(b, M, reg, beta):
>>> M = ot.dist(X_source, Y_target)
>>> method = "ASGD"
>>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg,
- method, numItermax,
- lr)
+ method, numItermax)
>>> print(asgd_pi)
References
@@ -328,7 +324,7 @@ def c_transform_entropic(b, M, reg, beta):
return alpha
-def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1,
+def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
log=False):
'''
Compute the transportation matrix to solve the regularized discrete
@@ -388,7 +384,6 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1,
>>> n_target = 4
>>> reg = 1
>>> numItermax = 300000
- >>> lr = 1
>>> a = ot.utils.unif(n_source)
>>> b = ot.utils.unif(n_target)
>>> rng = np.random.RandomState(0)
@@ -397,8 +392,7 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1,
>>> M = ot.dist(X_source, Y_target)
>>> method = "ASGD"
>>> asgd_pi = stochastic.solve_semi_dual_entropic(a, b, M, reg,
- method, numItermax,
- lr)
+ method, numItermax)
>>> print(asgd_pi)
References
@@ -409,10 +403,11 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=0.1,
Advances in Neural Information Processing Systems (2016),
arXiv preprint arxiv:1605.08527.
'''
+
if method.lower() == "sag":
opt_beta = sag_entropic_transport(a, b, M, reg, numItermax, lr)
elif method.lower() == "asgd":
- opt_beta = averaged_sgd_entropic_transport(b, M, reg, numItermax, lr)
+ opt_beta = averaged_sgd_entropic_transport(a, b, M, reg, numItermax, lr)
else:
print("Please, select your method between SAG and ASGD")
return None
diff --git a/test/test_stochastic.py b/test/test_stochastic.py
index 3cb51ff..f315c88 100644
--- a/test/test_stochastic.py
+++ b/test/test_stochastic.py
@@ -63,7 +63,6 @@ def test_stochastic_asgd():
n = 15
reg = 1
numItermax = 300000
- lr = 1
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
@@ -72,8 +71,7 @@ def test_stochastic_asgd():
M = ot.dist(x, x)
G = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd",
- numItermax=numItermax,
- lr=lr)
+ numItermax=numItermax)
# check constratints
np.testing.assert_allclose(
@@ -95,7 +93,6 @@ def test_sag_asgd_sinkhorn():
n = 15
reg = 1
nb_iter = 300000
- lr = 1
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
@@ -104,7 +101,7 @@ def test_sag_asgd_sinkhorn():
M = ot.dist(x, x)
G_asgd = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd",
- numItermax=nb_iter, lr=lr)
+ numItermax=nb_iter)
G_sag = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "sag",
numItermax=nb_iter)
G_sinkhorn = ot.sinkhorn(u, u, M, reg)