summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPanayiotis Panayiotou <p.panayiotou2@gmail.com>2020-08-24 15:40:05 +0300
committerGitHub <noreply@github.com>2020-08-24 14:40:05 +0200
commit24a7a0439e631e90ff84ce84d0a78bc22846cf71 (patch)
tree0514f646d997723ec73e2af967d4a284f84abf10
parent679ed3120da21d620b7cd9a838e073c817653864 (diff)
Check if alpha is not None when restricting it to be at most 1 (#199)
* Check if alpha is not None when restricting it to be at most 1 * Write check more clearly * Add no regression test for line search armijo returning None for alpha Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
-rw-r--r--ot/optim.py5
-rw-r--r--test/test_optim.py10
2 files changed, 14 insertions, 1 deletions
diff --git a/ot/optim.py b/ot/optim.py
index e7e6e65..1902907 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -69,7 +69,10 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
alpha, phi1 = scalar_search_armijo(
phi, phi0, derphi0, c1=c1, alpha0=alpha0)
- return min(1, alpha), fc[0], phi1
+ # scalar_search_armijo can return alpha > 1
+ if alpha is not None:
+ alpha = min(1, alpha)
+ return alpha, fc[0], phi1
def solve_linesearch(cost, G, deltaG, Mi, f_val,
diff --git a/test/test_optim.py b/test/test_optim.py
index 87b0268..48de38a 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -104,3 +104,13 @@ def test_solve_1d_linesearch_quad_funct():
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1, 0), 0.5)
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5, 0), 0)
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1)
+
+
+def test_line_search_armijo():
+ xk = np.array([[0.25, 0.25], [0.25, 0.25]])
+ pk = np.array([[-0.25, 0.25], [0.25, -0.25]])
+ gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]])
+ old_fval = -123
+ # Should not throw an exception and return None for alpha
+ alpha, _, _ = ot.optim.line_search_armijo(lambda x: 1, xk, pk, gfk, old_fval)
+ assert alpha is None