summaryrefslogtreecommitdiff
path: root/test/test_optim.py
diff options
context:
space:
mode:
authorncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>2021-09-28 16:34:28 +0200
committerGitHub <noreply@github.com>2021-09-28 16:34:28 +0200
commit7dde9e8e4b6aae756e103d49198caaa4f24150e3 (patch)
tree3961588cfe35d371ebf399bd6c138c2a1bcb1697 /test/test_optim.py
parente0ba31ce39a7d9e65e50ea970a574b3db54e4207 (diff)
[MRG] Regularized OT (optim.cg) bug solve (#286)
* Line search stops when derphi is 0 instead of bugging out like in some instances * pep8 compliance * Tests
Diffstat (limited to 'test/test_optim.py')
-rw-r--r--test/test_optim.py25
1 files changed, 25 insertions, 0 deletions
diff --git a/test/test_optim.py b/test/test_optim.py
index fd194c2..94995d5 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -114,3 +114,28 @@ def test_line_search_armijo():
# Should not throw an exception and return None for alpha
alpha, _, _ = ot.optim.line_search_armijo(lambda x: 1, xk, pk, gfk, old_fval)
assert alpha is None
+
+ # check line search armijo
+ def f(x):
+ return np.sum((x - 5.0) ** 2)
+
+ def grad(x):
+ return 2 * (x - 5.0)
+
+ xk = np.array([[[-5.0, -5.0]]])
+ pk = np.array([[[100.0, 100.0]]])
+ gfk = grad(xk)
+ old_fval = f(xk)
+
+ # chech the case where the optimum is on the direction
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval)
+ np.testing.assert_allclose(alpha, 0.1)
+
+ # check the case where the direction is not far enough
+ pk = np.array([[[3.0, 3.0]]])
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval, alpha0=1.0)
+ np.testing.assert_allclose(alpha, 1.0)
+
+ # check the case where the checking the wrong direction
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, -pk, gfk, old_fval)
+ assert alpha <= 0