diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2018-09-14 08:08:00 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-09-14 08:08:00 +0200 |
commit | 4367a343aeb0ceccbb99acc0f92797af020bb537 (patch) | |
tree | 53eebefbe8e13d94157be88167c7b6df4c78bdc1 /ot | |
parent | da5d07b4949877148f1582a9f0649c34282afa30 (diff) | |
parent | 2b8b18082257edcbbfe503ef2643f235161930b7 (diff) |
Merge pull request #65 from kilianFatras/stochastic_OT
Better implementation of stochastic gradient updates
Diffstat (limited to 'ot')
-rw-r--r-- | ot/stochastic.py | 4 |
1 file changed, 2 insertions, 2 deletions
diff --git a/ot/stochastic.py b/ot/stochastic.py index a369ba8..ec53015 100644 --- a/ot/stochastic.py +++ b/ot/stochastic.py @@ -617,8 +617,8 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr): update_alpha, update_beta = batch_grad_dual(a, b, M, reg, cur_alpha, cur_beta, batch_size, batch_alpha, batch_beta) - cur_alpha += (lr / k) * update_alpha - cur_beta += (lr / k) * update_beta + cur_alpha[batch_alpha] += (lr / k) * update_alpha[batch_alpha] + cur_beta[batch_beta] += (lr / k) * update_beta[batch_beta] return cur_alpha, cur_beta |