summaryrefslogtreecommitdiff
path: root/ot/stochastic.py
diff options
context:
space:
mode:
authorKilian Fatras <kilianfatras@Kilians-MacBook-Air.local>2018-09-13 12:15:39 -0700
committerKilian Fatras <kilianfatras@Kilians-MacBook-Air.local>2018-09-13 12:15:39 -0700
commit2b8b18082257edcbbfe503ef2643f235161930b7 (patch)
treef825c3d0c6a761523a65bbe4ee67fc69efef6572 /ot/stochastic.py
parent63b34bf012076eb89ed112122fdaa65667464ae7 (diff)
better implementation on gradient updates
Diffstat (limited to 'ot/stochastic.py')
-rw-r--r--ot/stochastic.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/ot/stochastic.py b/ot/stochastic.py
index a369ba8..ec53015 100644
--- a/ot/stochastic.py
+++ b/ot/stochastic.py
@@ -617,8 +617,8 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
update_alpha, update_beta = batch_grad_dual(a, b, M, reg, cur_alpha,
cur_beta, batch_size,
batch_alpha, batch_beta)
- cur_alpha += (lr / k) * update_alpha
- cur_beta += (lr / k) * update_beta
+ cur_alpha[batch_alpha] += (lr / k) * update_alpha[batch_alpha]
+ cur_beta[batch_beta] += (lr / k) * update_beta[batch_beta]
return cur_alpha, cur_beta