summaryrefslogtreecommitdiff
path: root/ot/stochastic.py
diff options
context:
space:
mode:
authorKilian Fatras <kilianfatras@dhcp-206-12-53-20.eduroam.wireless.ubc.ca>2018-08-28 18:51:28 -0700
committerKilian Fatras <kilianfatras@dhcp-206-12-53-20.eduroam.wireless.ubc.ca>2018-08-28 18:51:28 -0700
commit37e3b29595223399ebe4710ac2bb061004814118 (patch)
treec76a57b9a2412af918f2be87137c93e154264ac7 /ot/stochastic.py
parentcd193f78d392143ea9421da0f7e55ca8b75b8a0d (diff)
fixed argument functions
Diffstat (limited to 'ot/stochastic.py')
-rw-r--r--ot/stochastic.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/ot/stochastic.py b/ot/stochastic.py
index 0788f61..e33f6a0 100644
--- a/ot/stochastic.py
+++ b/ot/stochastic.py
@@ -435,7 +435,7 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
##############################################################################
-def batch_grad_dual(M, reg, a, b, alpha, beta, batch_size, batch_alpha,
+def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
batch_beta):
'''
Computes the partial gradient of F_\W_varepsilon
@@ -528,7 +528,7 @@ def batch_grad_dual(M, reg, a, b, alpha, beta, batch_size, batch_alpha,
return grad_alpha, grad_beta
-def sgd_entropic_regularization(M, reg, a, b, batch_size, numItermax, lr):
+def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
'''
Compute the sgd algorithm to solve the regularized discrete measures
optimal transport dual problem
@@ -612,7 +612,7 @@ def sgd_entropic_regularization(M, reg, a, b, batch_size, numItermax, lr):
k = np.sqrt(cur_iter / 100 + 1)
batch_alpha = np.random.choice(n_source, batch_size, replace=False)
batch_beta = np.random.choice(n_target, batch_size, replace=False)
- update_alpha, update_beta = batch_grad_dual(M, reg, a, b, cur_alpha,
+ update_alpha, update_beta = batch_grad_dual(a, b, M, reg, cur_alpha,
cur_beta, batch_size,
batch_alpha, batch_beta)
cur_alpha += (lr / k) * update_alpha
@@ -698,7 +698,7 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
arXiv preprint arxiv:1711.02283.
'''
- opt_alpha, opt_beta = sgd_entropic_regularization(M, reg, a, b, batch_size,
+ opt_alpha, opt_beta = sgd_entropic_regularization(a, b, M, reg, batch_size,
numItermax, lr)
pi = (np.exp((opt_alpha[:, None] + opt_beta[None, :] - M[:, :]) / reg) *
a[:, None] * b[None, :])