diff options
author | Kilian Fatras <kilianfatras@Kilians-MacBook-Air.local> | 2018-09-13 12:15:39 -0700 |
---|---|---|
committer | Kilian Fatras <kilianfatras@Kilians-MacBook-Air.local> | 2018-09-13 12:15:39 -0700 |
commit | 2b8b18082257edcbbfe503ef2643f235161930b7 (patch) | |
tree | f825c3d0c6a761523a65bbe4ee67fc69efef6572 /ot/stochastic.py | |
parent | 63b34bf012076eb89ed112122fdaa65667464ae7 (diff) |
better implementation on gradient updates
Diffstat (limited to 'ot/stochastic.py')
-rw-r--r-- | ot/stochastic.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/ot/stochastic.py b/ot/stochastic.py index a369ba8..ec53015 100644 --- a/ot/stochastic.py +++ b/ot/stochastic.py @@ -617,8 +617,8 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr): update_alpha, update_beta = batch_grad_dual(a, b, M, reg, cur_alpha, cur_beta, batch_size, batch_alpha, batch_beta) - cur_alpha += (lr / k) * update_alpha - cur_beta += (lr / k) * update_beta + cur_alpha[batch_alpha] += (lr / k) * update_alpha[batch_alpha] + cur_beta[batch_beta] += (lr / k) * update_beta[batch_beta] return cur_alpha, cur_beta |