summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-11-19 11:17:07 +0100
committerRémi Flamary <remi.flamary@gmail.com>2018-11-19 11:17:07 +0100
commit93db239e1156ad1db8edbb13c1ecde973ce009c0 (patch)
tree14101caa2699d0b90d165303658f1540d1e87a5c /ot/bregman.py
parent87930c4bcddfded480983343ecc68c6b94bcce14 (diff)
remove W605 errors
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py18
1 files changed, 9 insertions, 9 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index d1057ff..43340f7 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -370,9 +370,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
v = np.divide(b, KtransposeU)
u = 1. / np.dot(Kp, v)
- if (np.any(KtransposeU == 0) or
- np.any(np.isnan(u)) or np.any(np.isnan(v)) or
- np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ if (np.any(KtransposeU == 0)
+ or np.any(np.isnan(u)) or np.any(np.isnan(v))
+ or np.any(np.isinf(u)) or np.any(np.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
print('Warning: numerical errors at iteration', cpt)
@@ -683,13 +683,13 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
def get_K(alpha, beta):
"""log space computation"""
- return np.exp(-(M - alpha.reshape((na, 1)) -
- beta.reshape((1, nb))) / reg)
+ return np.exp(-(M - alpha.reshape((na, 1))
+ - beta.reshape((1, nb))) / reg)
def get_Gamma(alpha, beta, u, v):
"""log space gamma computation"""
- return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) /
- reg + np.log(u.reshape((na, 1))) + np.log(v.reshape((1, nb))))
+ return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb)))
+ / reg + np.log(u.reshape((na, 1))) + np.log(v.reshape((1, nb))))
# print(np.min(K))
@@ -899,8 +899,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
def get_K(alpha, beta):
"""log space computation"""
- return np.exp(-(M - alpha.reshape((na, 1)) -
- beta.reshape((1, nb))) / reg)
+ return np.exp(-(M - alpha.reshape((na, 1))
+ - beta.reshape((1, nb))) / reg)
# print(np.min(K))
def get_reg(n): # exponential decreasing