From 7dde9e8e4b6aae756e103d49198caaa4f24150e3 Mon Sep 17 00:00:00 2001 From: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> Date: Tue, 28 Sep 2021 16:34:28 +0200 Subject: [MRG] Regularized OT (optim.cg) bug solve (#286) * Line search stops when derphi is 0 instead of bugging out like in some instances * pep8 compliance * Tests --- test/test_da.py | 8 ++++++++ test/test_optim.py | 25 +++++++++++++++++++++++++ 2 files changed, 33 insertions(+) (limited to 'test') diff --git a/test/test_da.py b/test/test_da.py index 44bb2e9..9f2bb50 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -565,6 +565,14 @@ def test_mapping_transport_class(): otda.fit(Xs=Xs, Xt=Xt) assert len(otda.log_.keys()) != 0 + # check that it does not crash when derphi is very close to 0 + np.random.seed(39) + Xs, ys = make_data_classif('3gauss', ns) + Xt, yt = make_data_classif('3gauss2', nt) + otda = ot.da.MappingTransport(kernel="gaussian", bias=False) + otda.fit(Xs=Xs, Xt=Xt) + np.random.seed(None) + def test_linear_mapping(): ns = 150 diff --git a/test/test_optim.py b/test/test_optim.py index fd194c2..94995d5 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -114,3 +114,28 @@ def test_line_search_armijo(): # Should not throw an exception and return None for alpha alpha, _, _ = ot.optim.line_search_armijo(lambda x: 1, xk, pk, gfk, old_fval) assert alpha is None + + # check line search armijo + def f(x): + return np.sum((x - 5.0) ** 2) + + def grad(x): + return 2 * (x - 5.0) + + xk = np.array([[[-5.0, -5.0]]]) + pk = np.array([[[100.0, 100.0]]]) + gfk = grad(xk) + old_fval = f(xk) + + # chech the case where the optimum is on the direction + alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval) + np.testing.assert_allclose(alpha, 0.1) + + # check the case where the direction is not far enough + pk = np.array([[[3.0, 3.0]]]) + alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval, alpha0=1.0) + np.testing.assert_allclose(alpha, 1.0) + + # check the case where the checking the wrong direction + alpha, _, _ = ot.optim.line_search_armijo(f, xk, -pk, gfk, old_fval) + assert alpha <= 0 -- cgit v1.2.3