summaryrefslogtreecommitdiff
path: root/ot/stochastic.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-11-19 11:17:07 +0100
committerRémi Flamary <remi.flamary@gmail.com>2018-11-19 11:17:07 +0100
commit93db239e1156ad1db8edbb13c1ecde973ce009c0 (patch)
tree14101caa2699d0b90d165303658f1540d1e87a5c /ot/stochastic.py
parent87930c4bcddfded480983343ecc68c6b94bcce14 (diff)
remove W605 errors
Diffstat (limited to 'ot/stochastic.py')
-rw-r--r--ot/stochastic.py20
1 files changed, 10 insertions, 10 deletions
diff --git a/ot/stochastic.py b/ot/stochastic.py
index ec53015..1376884 100644
--- a/ot/stochastic.py
+++ b/ot/stochastic.py
@@ -418,8 +418,8 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
return None
opt_alpha = c_transform_entropic(b, M, reg, opt_beta)
- pi = (np.exp((opt_alpha[:, None] + opt_beta[None, :] - M[:, :]) / reg) *
- a[:, None] * b[None, :])
+ pi = (np.exp((opt_alpha[:, None] + opt_beta[None, :] - M[:, :]) / reg)
+ * a[:, None] * b[None, :])
if log:
log = {}
@@ -520,15 +520,15 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
arXiv preprint arxiv:1711.02283.
'''
- G = - (np.exp((alpha[batch_alpha, None] + beta[None, batch_beta] -
- M[batch_alpha, :][:, batch_beta]) / reg) *
+ G = - (np.exp((alpha[batch_alpha, None] + beta[None, batch_beta]
+ - M[batch_alpha, :][:, batch_beta]) / reg) *
a[batch_alpha, None] * b[None, batch_beta])
grad_beta = np.zeros(np.shape(M)[1])
grad_alpha = np.zeros(np.shape(M)[0])
- grad_beta[batch_beta] = (b[batch_beta] * len(batch_alpha) / np.shape(M)[0] +
- G.sum(0))
- grad_alpha[batch_alpha] = (a[batch_alpha] * len(batch_beta) /
- np.shape(M)[1] + G.sum(1))
+ grad_beta[batch_beta] = (b[batch_beta] * len(batch_alpha) / np.shape(M)[0]
+ + G.sum(0))
+ grad_alpha[batch_alpha] = (a[batch_alpha] * len(batch_beta)
+ / np.shape(M)[1] + G.sum(1))
return grad_alpha, grad_beta
@@ -702,8 +702,8 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
opt_alpha, opt_beta = sgd_entropic_regularization(a, b, M, reg, batch_size,
numItermax, lr)
- pi = (np.exp((opt_alpha[:, None] + opt_beta[None, :] - M[:, :]) / reg) *
- a[:, None] * b[None, :])
+ pi = (np.exp((opt_alpha[:, None] + opt_beta[None, :] - M[:, :]) / reg)
+ * a[:, None] * b[None, :])
if log:
log = {}
log['alpha'] = opt_alpha