summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-07-26 12:22:00 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-07-26 12:22:00 +0200
commit84aa3183491260b9c3dbb9f928499cc18e5341c1 (patch)
tree08bada6463e76c8dbf3f873fa18834c2712bb0bb /ot/bregman.py
parent251af8eec2b39e74000242cbf5bff5e13910cfe8 (diff)
pep8
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py16
1 files changed, 9 insertions, 7 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 71a5548..929388e 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -108,7 +108,8 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, ver
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
elif method.lower() == 'sinkhorn_epsilon_scaling':
def sink():
- return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax,
+ return sinkhorn_epsilon_scaling(
+ a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
else:
print('Warning : unknown method using classic Sinkhorn Knopp')
@@ -216,7 +217,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, ve
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
elif method.lower() == 'sinkhorn_epsilon_scaling':
def sink():
- return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax,
+ return sinkhorn_epsilon_scaling(
+ a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
else:
print('Warning : unknown method using classic Sinkhorn Knopp')
@@ -593,7 +595,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, wa
cpt = cpt + 1
- #print('err=',err,' cpt=',cpt)
+ # print('err=',err,' cpt=',cpt)
if log:
log['logu'] = alpha / reg + np.log(u)
log['logv'] = beta / reg + np.log(v)
@@ -778,7 +780,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
loop = False
cpt = cpt + 1
- #print('err=',err,' cpt=',cpt)
+ # print('err=',err,' cpt=',cpt)
if log:
log['alpha'] = alpha
log['beta'] = beta
@@ -965,16 +967,16 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, stopThr=1e-3, verb
"""
- #M = M/np.median(M)
+ # M = M/np.median(M)
K = np.exp(-M / reg)
- #M0 = M0/np.median(M0)
+ # M0 = M0/np.median(M0)
K0 = np.exp(-M0 / reg0)
old = h0
err = 1
cpt = 0
- #log = {'niter':0, 'all_err':[]}
+ # log = {'niter':0, 'all_err':[]}
if log:
log = {'err': []}