From 7dde9e8e4b6aae756e103d49198caaa4f24150e3 Mon Sep 17 00:00:00 2001 From: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> Date: Tue, 28 Sep 2021 16:34:28 +0200 Subject: [MRG] Regularized OT (optim.cg) bug solve (#286) * Line search stops when derphi is 0 instead of bugging out like in some instances * pep8 compliance * Tests --- ot/optim.py | 10 ++++++---- test/test_da.py | 8 ++++++++ test/test_optim.py | 25 +++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 4 deletions(-) diff --git a/ot/optim.py b/ot/optim.py index abe9e6a..0359343 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -178,9 +178,9 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, numItermaxEmd : int, optional Max number of iterations for emd stopThr : float, optional - Stop threshol on the relative variation (>0) + Stop threshold on the relative variation (>0) stopThr2 : float, optional - Stop threshol on the absolute variation (>0) + Stop threshold on the absolute variation (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -249,6 +249,8 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, # line search alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs) + if alpha is None: + alpha = 0.0 G = G + alpha * deltaG @@ -320,9 +322,9 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, numInnerItermax : int, optional Max number of iterations of Sinkhorn stopThr : float, optional - Stop threshol on the relative variation (>0) + Stop threshold on the relative variation (>0) stopThr2 : float, optional - Stop threshol on the absolute variation (>0) + Stop threshold on the absolute variation (>0) verbose : bool, optional Print information along iterations log : bool, optional diff --git a/test/test_da.py b/test/test_da.py index 44bb2e9..9f2bb50 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -565,6 +565,14 @@ def test_mapping_transport_class(): otda.fit(Xs=Xs, Xt=Xt) assert len(otda.log_.keys()) != 0 + # check that it does not crash when derphi is very close to 0 + np.random.seed(39) + Xs, ys = make_data_classif('3gauss', ns) + Xt, yt = make_data_classif('3gauss2', nt) + otda = ot.da.MappingTransport(kernel="gaussian", bias=False) + otda.fit(Xs=Xs, Xt=Xt) + np.random.seed(None) + def test_linear_mapping(): ns = 150 diff --git a/test/test_optim.py b/test/test_optim.py index fd194c2..94995d5 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -114,3 +114,28 @@ def test_line_search_armijo(): # Should not throw an exception and return None for alpha alpha, _, _ = ot.optim.line_search_armijo(lambda x: 1, xk, pk, gfk, old_fval) assert alpha is None + + # check line search armijo + def f(x): + return np.sum((x - 5.0) ** 2) + + def grad(x): + return 2 * (x - 5.0) + + xk = np.array([[[-5.0, -5.0]]]) + pk = np.array([[[100.0, 100.0]]]) + gfk = grad(xk) + old_fval = f(xk) + + # chech the case where the optimum is on the direction + alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval) + np.testing.assert_allclose(alpha, 0.1) + + # check the case where the direction is not far enough + pk = np.array([[[3.0, 3.0]]]) + alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval, alpha0=1.0) + np.testing.assert_allclose(alpha, 1.0) + + # check the case where the checking the wrong direction + alpha, _, _ = ot.optim.line_search_armijo(f, xk, -pk, gfk, old_fval) + assert alpha <= 0 -- cgit v1.2.3