diff options
Diffstat (limited to 'ot/stochastic.py')
-rw-r--r-- | ot/stochastic.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/ot/stochastic.py b/ot/stochastic.py index 0788f61..e33f6a0 100644 --- a/ot/stochastic.py +++ b/ot/stochastic.py @@ -435,7 +435,7 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None, ############################################################################## -def batch_grad_dual(M, reg, a, b, alpha, beta, batch_size, batch_alpha, +def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha, batch_beta): ''' Computes the partial gradient of F_\W_varepsilon @@ -528,7 +528,7 @@ def batch_grad_dual(M, reg, a, b, alpha, beta, batch_size, batch_alpha, return grad_alpha, grad_beta -def sgd_entropic_regularization(M, reg, a, b, batch_size, numItermax, lr): +def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr): ''' Compute the sgd algorithm to solve the regularized discrete measures optimal transport dual problem @@ -612,7 +612,7 @@ def sgd_entropic_regularization(M, reg, a, b, batch_size, numItermax, lr): k = np.sqrt(cur_iter / 100 + 1) batch_alpha = np.random.choice(n_source, batch_size, replace=False) batch_beta = np.random.choice(n_target, batch_size, replace=False) - update_alpha, update_beta = batch_grad_dual(M, reg, a, b, cur_alpha, + update_alpha, update_beta = batch_grad_dual(a, b, M, reg, cur_alpha, cur_beta, batch_size, batch_alpha, batch_beta) cur_alpha += (lr / k) * update_alpha @@ -698,7 +698,7 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1, arXiv preprint arxiv:1711.02283. ''' - opt_alpha, opt_beta = sgd_entropic_regularization(M, reg, a, b, batch_size, + opt_alpha, opt_beta = sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr) pi = (np.exp((opt_alpha[:, None] + opt_beta[None, :] - M[:, :]) / reg) * a[:, None] * b[None, :]) |