diff options
Diffstat (limited to 'test/test_optim.py')
-rw-r--r-- | test/test_optim.py | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/test/test_optim.py b/test/test_optim.py index fd194c2..94995d5 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -114,3 +114,28 @@ def test_line_search_armijo(): # Should not throw an exception and return None for alpha alpha, _, _ = ot.optim.line_search_armijo(lambda x: 1, xk, pk, gfk, old_fval) assert alpha is None + + # check line search armijo + def f(x): + return np.sum((x - 5.0) ** 2) + + def grad(x): + return 2 * (x - 5.0) + + xk = np.array([[[-5.0, -5.0]]]) + pk = np.array([[[100.0, 100.0]]]) + gfk = grad(xk) + old_fval = f(xk) + + # chech the case where the optimum is on the direction + alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval) + np.testing.assert_allclose(alpha, 0.1) + + # check the case where the direction is not far enough + pk = np.array([[[3.0, 3.0]]]) + alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval, alpha0=1.0) + np.testing.assert_allclose(alpha, 1.0) + + # check the case where the checking the wrong direction + alpha, _, _ = ot.optim.line_search_armijo(f, xk, -pk, gfk, old_fval) + assert alpha <= 0 |