summaryrefslogtreecommitdiff
path: root/test/test_optim.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_optim.py')
-rw-r--r--test/test_optim.py63
1 files changed, 56 insertions, 7 deletions
diff --git a/test/test_optim.py b/test/test_optim.py
index 67e9d13..a43e704 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -120,31 +120,33 @@ def test_generalized_conditional_gradient(nx):
Gb, log = ot.optim.gcg(ab, bb, Mb, reg1, reg2, fb, df, verbose=True, log=True)
Gb = nx.to_numpy(Gb)
- np.testing.assert_allclose(Gb, G)
+ np.testing.assert_allclose(Gb, G, atol=1e-12)
np.testing.assert_allclose(a, Gb.sum(1), atol=1e-05)
np.testing.assert_allclose(b, Gb.sum(0), atol=1e-05)
def test_solve_1d_linesearch_quad_funct():
- np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1, 0), 0.5)
- np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5, 0), 0)
- np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1)
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1), 0.5)
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5), 0)
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5), 1)
def test_line_search_armijo(nx):
xk = np.array([[0.25, 0.25], [0.25, 0.25]])
pk = np.array([[-0.25, 0.25], [0.25, -0.25]])
gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]])
- old_fval = -123
+ old_fval = -123.
xkb, pkb, gfkb = nx.from_numpy(xk, pk, gfk)
+ def f(x):
+ return 1.
# Should not throw an exception and return 0. for alpha
alpha, a, b = ot.optim.line_search_armijo(
- lambda x: 1, xkb, pkb, gfkb, old_fval
+ f, xkb, pkb, gfkb, old_fval
)
alpha_np, anp, bnp = ot.optim.line_search_armijo(
- lambda x: 1, xk, pk, gfk, old_fval
+ f, xk, pk, gfk, old_fval
)
assert a == anp
assert b == bnp
@@ -182,3 +184,50 @@ def test_line_search_armijo(nx):
old_fval = f(xk)
alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval)
np.testing.assert_allclose(alpha, 0.1)
+
+
+def test_line_search_armijo_dtype_device(nx):
+ for tp in nx.__type_list__:
+ def f(x):
+ return nx.sum((x - 5.0) ** 2)
+
+ def grad(x):
+ return 2 * (x - 5.0)
+
+ xk = np.array([[[-5.0, -5.0]]])
+ pk = np.array([[[100.0, 100.0]]])
+ xkb, pkb = nx.from_numpy(xk, pk, type_as=tp)
+ gfkb = grad(xkb)
+ old_fval = f(xkb)
+
+ # chech the case where the optimum is on the direction
+ alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval)
+ alpha = nx.to_numpy(alpha)
+ np.testing.assert_allclose(alpha, 0.1)
+ nx.assert_same_dtype_device(old_fval, fval)
+
+ # check the case where the direction is not far enough
+ pk = np.array([[[3.0, 3.0]]])
+ pkb = nx.from_numpy(pk, type_as=tp)
+ alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval, alpha0=1.0)
+ alpha = nx.to_numpy(alpha)
+ np.testing.assert_allclose(alpha, 1.0)
+ nx.assert_same_dtype_device(old_fval, fval)
+
+ # check the case where checking the wrong direction
+ alpha, _, fval = ot.optim.line_search_armijo(f, xkb, -pkb, gfkb, old_fval)
+ alpha = nx.to_numpy(alpha)
+
+ assert alpha <= 0
+ nx.assert_same_dtype_device(old_fval, fval)
+
+ # check the case where the point is not a vector
+ xkb = nx.from_numpy(np.array(-5.0), type_as=tp)
+ pkb = nx.from_numpy(np.array(100), type_as=tp)
+ gfkb = grad(xkb)
+ old_fval = f(xkb)
+ alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval)
+ alpha = nx.to_numpy(alpha)
+
+ np.testing.assert_allclose(alpha, 0.1)
+ nx.assert_same_dtype_device(old_fval, fval)