diff options
author | ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> | 2021-09-28 16:34:28 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-09-28 16:34:28 +0200 |
commit | 7dde9e8e4b6aae756e103d49198caaa4f24150e3 (patch) | |
tree | 3961588cfe35d371ebf399bd6c138c2a1bcb1697 /ot | |
parent | e0ba31ce39a7d9e65e50ea970a574b3db54e4207 (diff) |
[MRG] Regularized OT (optim.cg) bug solve (#286)
* Line search stops when derphi is 0 instead of bugging out like in some instances
* pep8 compliance
* Tests
Diffstat (limited to 'ot')
-rw-r--r-- | ot/optim.py | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/ot/optim.py b/ot/optim.py index abe9e6a..0359343 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -178,9 +178,9 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, numItermaxEmd : int, optional Max number of iterations for emd stopThr : float, optional - Stop threshol on the relative variation (>0) + Stop threshold on the relative variation (>0) stopThr2 : float, optional - Stop threshol on the absolute variation (>0) + Stop threshold on the absolute variation (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -249,6 +249,8 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, # line search alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs) + if alpha is None: + alpha = 0.0 G = G + alpha * deltaG @@ -320,9 +322,9 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, numInnerItermax : int, optional Max number of iterations of Sinkhorn stopThr : float, optional - Stop threshol on the relative variation (>0) + Stop threshold on the relative variation (>0) stopThr2 : float, optional - Stop threshol on the absolute variation (>0) + Stop threshold on the absolute variation (>0) verbose : bool, optional Print information along iterations log : bool, optional |