diff options
author | ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> | 2021-09-28 16:34:28 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-09-28 16:34:28 +0200 |
commit | 7dde9e8e4b6aae756e103d49198caaa4f24150e3 (patch) | |
tree | 3961588cfe35d371ebf399bd6c138c2a1bcb1697 /test/test_optim.py | |
parent | e0ba31ce39a7d9e65e50ea970a574b3db54e4207 (diff) |
[MRG] Regularized OT (optim.cg) bug solve (#286)
* Line search stops when derphi is 0 instead of bugging out like in some instances
* pep8 compliance
* Tests
Diffstat (limited to 'test/test_optim.py')
-rw-r--r-- | test/test_optim.py | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/test/test_optim.py b/test/test_optim.py index fd194c2..94995d5 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -114,3 +114,28 @@ def test_line_search_armijo(): # Should not throw an exception and return None for alpha alpha, _, _ = ot.optim.line_search_armijo(lambda x: 1, xk, pk, gfk, old_fval) assert alpha is None + + # check line search armijo + def f(x): + return np.sum((x - 5.0) ** 2) + + def grad(x): + return 2 * (x - 5.0) + + xk = np.array([[[-5.0, -5.0]]]) + pk = np.array([[[100.0, 100.0]]]) + gfk = grad(xk) + old_fval = f(xk) + + # chech the case where the optimum is on the direction + alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval) + np.testing.assert_allclose(alpha, 0.1) + + # check the case where the direction is not far enough + pk = np.array([[[3.0, 3.0]]]) + alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval, alpha0=1.0) + np.testing.assert_allclose(alpha, 1.0) + + # check the case where the checking the wrong direction + alpha, _, _ = ot.optim.line_search_armijo(f, xk, -pk, gfk, old_fval) + assert alpha <= 0 |