summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2018-09-14 08:08:00 +0200
committerGitHub <noreply@github.com>2018-09-14 08:08:00 +0200
commit4367a343aeb0ceccbb99acc0f92797af020bb537 (patch)
tree53eebefbe8e13d94157be88167c7b6df4c78bdc1 /ot
parentda5d07b4949877148f1582a9f0649c34282afa30 (diff)
parent2b8b18082257edcbbfe503ef2643f235161930b7 (diff)
Merge pull request #65 from kilianFatras/stochastic_OT
better implementation on stocjastic gradient updates
Diffstat (limited to 'ot')
-rw-r--r--ot/stochastic.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/ot/stochastic.py b/ot/stochastic.py
index a369ba8..ec53015 100644
--- a/ot/stochastic.py
+++ b/ot/stochastic.py
@@ -617,8 +617,8 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
update_alpha, update_beta = batch_grad_dual(a, b, M, reg, cur_alpha,
cur_beta, batch_size,
batch_alpha, batch_beta)
- cur_alpha += (lr / k) * update_alpha
- cur_beta += (lr / k) * update_beta
+ cur_alpha[batch_alpha] += (lr / k) * update_alpha[batch_alpha]
+ cur_beta[batch_beta] += (lr / k) * update_beta[batch_beta]
return cur_alpha, cur_beta