diff options
Diffstat (limited to 'test/test_optim.py')
-rw-r--r-- | test/test_optim.py | 55 |
1 files changed, 52 insertions, 3 deletions
diff --git a/test/test_optim.py b/test/test_optim.py index 129fe22..a43e704 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -135,16 +135,18 @@ def test_line_search_armijo(nx): xk = np.array([[0.25, 0.25], [0.25, 0.25]]) pk = np.array([[-0.25, 0.25], [0.25, -0.25]]) gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]]) - old_fval = -123 + old_fval = -123. xkb, pkb, gfkb = nx.from_numpy(xk, pk, gfk) + def f(x): + return 1. # Should not throw an exception and return 0. for alpha alpha, a, b = ot.optim.line_search_armijo( - lambda x: 1, xkb, pkb, gfkb, old_fval + f, xkb, pkb, gfkb, old_fval ) alpha_np, anp, bnp = ot.optim.line_search_armijo( - lambda x: 1, xk, pk, gfk, old_fval + f, xk, pk, gfk, old_fval ) assert a == anp assert b == bnp @@ -182,3 +184,50 @@ def test_line_search_armijo(nx): old_fval = f(xk) alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval) np.testing.assert_allclose(alpha, 0.1) + + +def test_line_search_armijo_dtype_device(nx): + for tp in nx.__type_list__: + def f(x): + return nx.sum((x - 5.0) ** 2) + + def grad(x): + return 2 * (x - 5.0) + + xk = np.array([[[-5.0, -5.0]]]) + pk = np.array([[[100.0, 100.0]]]) + xkb, pkb = nx.from_numpy(xk, pk, type_as=tp) + gfkb = grad(xkb) + old_fval = f(xkb) + + # chech the case where the optimum is on the direction + alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval) + alpha = nx.to_numpy(alpha) + np.testing.assert_allclose(alpha, 0.1) + nx.assert_same_dtype_device(old_fval, fval) + + # check the case where the direction is not far enough + pk = np.array([[[3.0, 3.0]]]) + pkb = nx.from_numpy(pk, type_as=tp) + alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval, alpha0=1.0) + alpha = nx.to_numpy(alpha) + np.testing.assert_allclose(alpha, 1.0) + nx.assert_same_dtype_device(old_fval, fval) + + # check the case where checking the wrong direction + alpha, _, fval = ot.optim.line_search_armijo(f, xkb, -pkb, gfkb, old_fval) + alpha = nx.to_numpy(alpha) + + assert alpha <= 0 + nx.assert_same_dtype_device(old_fval, fval) + + # check the case where the point is not a vector + xkb = nx.from_numpy(np.array(-5.0), type_as=tp) + pkb = nx.from_numpy(np.array(100), type_as=tp) + gfkb = grad(xkb) + old_fval = f(xkb) + alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval) + alpha = nx.to_numpy(alpha) + + np.testing.assert_allclose(alpha, 0.1) + nx.assert_same_dtype_device(old_fval, fval) |