summaryrefslogtreecommitdiff
path: root/test/test_optim.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_optim.py')
-rw-r--r--test/test_optim.py25
1 files changed, 25 insertions, 0 deletions
diff --git a/test/test_optim.py b/test/test_optim.py
index fd194c2..94995d5 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -114,3 +114,28 @@ def test_line_search_armijo():
# Should not throw an exception and return None for alpha
alpha, _, _ = ot.optim.line_search_armijo(lambda x: 1, xk, pk, gfk, old_fval)
assert alpha is None
+
+ # check line search armijo
+ def f(x):
+ return np.sum((x - 5.0) ** 2)
+
+ def grad(x):
+ return 2 * (x - 5.0)
+
+ xk = np.array([[[-5.0, -5.0]]])
+ pk = np.array([[[100.0, 100.0]]])
+ gfk = grad(xk)
+ old_fval = f(xk)
+
+ # chech the case where the optimum is on the direction
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval)
+ np.testing.assert_allclose(alpha, 0.1)
+
+ # check the case where the direction is not far enough
+ pk = np.array([[[3.0, 3.0]]])
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval, alpha0=1.0)
+ np.testing.assert_allclose(alpha, 1.0)
+
+ # check the case where the checking the wrong direction
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, -pk, gfk, old_fval)
+ assert alpha <= 0