summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-09-24 10:23:02 +0200
committerRémi Flamary <remi.flamary@gmail.com>2018-09-24 10:23:02 +0200
commit1d494107611c2e6e2249b7a624e64cec6357b4bd (patch)
tree83def7c63ad8b200709336248af04f5588f3eaf5 /ot
parent55e8392993919d3c67538756663abd943d3bb491 (diff)
implement for loop
Diffstat (limited to 'ot')
-rw-r--r--ot/bregman.py7
1 files changed, 6 insertions, 1 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 6e446a1..05f7c75 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -520,7 +520,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
log['u'] = u
log['v'] = v
- while i < numItermax and stopThr_val > stopThr:
+ for i in range(numItermax):
i += 1
i_1 = np.argmax(np.abs(viol))
i_2 = np.argmax(np.abs(viol_2))
@@ -547,6 +547,11 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
#print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2)))
+ if stopThr_val <= stopThr:
+ break
+ else:
+ print('Warning: Algorithm did not converge')
+
if log:
log['u'] = u
log['v'] = v