summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com>2021-11-17 11:16:24 +0100
committerGitHub <noreply@github.com>2021-11-17 11:16:24 +0100
commite235b08c2ac86f75b6c1b8e96e503305aa0449e1 (patch)
tree8699a99ef206f38cf4d6673b25f100441fa9cf32
parentf4b363d865a79c07248176c1e36990e0cb6814ea (diff)
[MRG] SinkhornL1L2Transport bug (#312)
* solve bug * Linesearch no longer return None as alpha, only 0
-rw-r--r--ot/optim.py10
-rw-r--r--test/test_optim.py4
2 files changed, 7 insertions, 7 deletions
diff --git a/ot/optim.py b/ot/optim.py
index bd8ca26..cacec53 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -77,10 +77,12 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
alpha, phi1 = scalar_search_armijo(
phi, phi0, derphi0, c1=c1, alpha0=alpha0)
- # scalar_search_armijo can return alpha > 1
- if alpha is not None:
+ if alpha is None:
+ return 0., fc[0], phi0
+ else:
+ # scalar_search_armijo can return alpha > 1
alpha = min(1, alpha)
- return alpha, fc[0], phi1
+ return alpha, fc[0], phi1
def solve_linesearch(cost, G, deltaG, Mi, f_val,
@@ -273,8 +275,6 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
# line search
alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs)
- if alpha is None:
- alpha = 0.0
G = G + alpha * deltaG
diff --git a/test/test_optim.py b/test/test_optim.py
index 4efd9b1..41f9cbe 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -142,7 +142,7 @@ def test_line_search_armijo(nx):
pk = np.array([[-0.25, 0.25], [0.25, -0.25]])
gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]])
old_fval = -123
- # Should not throw an exception and return None for alpha
+ # Should not throw an exception and return 0. for alpha
alpha, a, b = ot.optim.line_search_armijo(
lambda x: 1, nx.from_numpy(xk), nx.from_numpy(pk), nx.from_numpy(gfk), old_fval
)
@@ -151,7 +151,7 @@ def test_line_search_armijo(nx):
)
assert a == anp
assert b == bnp
- assert alpha is None
+ assert alpha == 0.
# check line search armijo
def f(x):