diff options
-rw-r--r-- | ot/optim.py | 5 | ||||
-rw-r--r-- | test/test_optim.py | 10 |
2 files changed, 14 insertions, 1 deletions
diff --git a/ot/optim.py b/ot/optim.py index e7e6e65..1902907 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -69,7 +69,10 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, alpha, phi1 = scalar_search_armijo( phi, phi0, derphi0, c1=c1, alpha0=alpha0) - return min(1, alpha), fc[0], phi1 + # scalar_search_armijo can return alpha > 1 + if alpha is not None: + alpha = min(1, alpha) + return alpha, fc[0], phi1 def solve_linesearch(cost, G, deltaG, Mi, f_val, diff --git a/test/test_optim.py b/test/test_optim.py index 87b0268..48de38a 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -104,3 +104,13 @@ def test_solve_1d_linesearch_quad_funct(): np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1, 0), 0.5) np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5, 0), 0) np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1) + + +def test_line_search_armijo(): + xk = np.array([[0.25, 0.25], [0.25, 0.25]]) + pk = np.array([[-0.25, 0.25], [0.25, -0.25]]) + gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]]) + old_fval = -123 + # Should not throw an exception and return None for alpha + alpha, _, _ = ot.optim.line_search_armijo(lambda x: 1, xk, pk, gfk, old_fval) + assert alpha is None |