summaryrefslogtreecommitdiff
path: root/ot/stochastic.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/stochastic.py')
-rw-r--r--ot/stochastic.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/ot/stochastic.py b/ot/stochastic.py
index 4795d88..0db39c8 100644
--- a/ot/stochastic.py
+++ b/ot/stochastic.py
@@ -536,10 +536,10 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
a[batch_alpha, None] * b[None, batch_beta])
grad_beta = np.zeros(np.shape(M)[1])
grad_alpha = np.zeros(np.shape(M)[0])
- grad_beta[batch_beta] = (b[batch_beta] * len(batch_alpha) / np.shape(M)[0] +
- G.sum(0))
- grad_alpha[batch_alpha] = (a[batch_alpha] * len(batch_beta) /
- np.shape(M)[1] + G.sum(1))
+ grad_beta[batch_beta] = (b[batch_beta] * len(batch_alpha) / np.shape(M)[0]
+ + G.sum(0))
+ grad_alpha[batch_alpha] = (a[batch_alpha] * len(batch_beta)
+ / np.shape(M)[1] + G.sum(1))
return grad_alpha, grad_beta