summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/test_da.py8
-rw-r--r--test/test_optim.py25
2 files changed, 33 insertions, 0 deletions
diff --git a/test/test_da.py b/test/test_da.py
index 44bb2e9..9f2bb50 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -565,6 +565,14 @@ def test_mapping_transport_class():
otda.fit(Xs=Xs, Xt=Xt)
assert len(otda.log_.keys()) != 0
+ # check that it does not crash when derphi is very close to 0
+ np.random.seed(39)
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+ otda = ot.da.MappingTransport(kernel="gaussian", bias=False)
+ otda.fit(Xs=Xs, Xt=Xt)
+ np.random.seed(None)
+
def test_linear_mapping():
ns = 150
diff --git a/test/test_optim.py b/test/test_optim.py
index fd194c2..94995d5 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -114,3 +114,28 @@ def test_line_search_armijo():
# Should not throw an exception and return None for alpha
alpha, _, _ = ot.optim.line_search_armijo(lambda x: 1, xk, pk, gfk, old_fval)
assert alpha is None
+
+ # check line search armijo
+ def f(x):
+ return np.sum((x - 5.0) ** 2)
+
+ def grad(x):
+ return 2 * (x - 5.0)
+
+ xk = np.array([[[-5.0, -5.0]]])
+ pk = np.array([[[100.0, 100.0]]])
+ gfk = grad(xk)
+ old_fval = f(xk)
+
+ # chech the case where the optimum is on the direction
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval)
+ np.testing.assert_allclose(alpha, 0.1)
+
+ # check the case where the direction is not far enough
+ pk = np.array([[[3.0, 3.0]]])
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval, alpha0=1.0)
+ np.testing.assert_allclose(alpha, 1.0)
+
+ # check the case where the checking the wrong direction
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, -pk, gfk, old_fval)
+ assert alpha <= 0