summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>2021-09-28 16:34:28 +0200
committerGitHub <noreply@github.com>2021-09-28 16:34:28 +0200
commit7dde9e8e4b6aae756e103d49198caaa4f24150e3 (patch)
tree3961588cfe35d371ebf399bd6c138c2a1bcb1697
parente0ba31ce39a7d9e65e50ea970a574b3db54e4207 (diff)
[MRG] Regularized OT (optim.cg) bug solve (#286)
* Line search stops when derphi is 0 instead of bugging out like in some instances * pep8 compliance * Tests
-rw-r--r--ot/optim.py10
-rw-r--r--test/test_da.py8
-rw-r--r--test/test_optim.py25
3 files changed, 39 insertions, 4 deletions
diff --git a/ot/optim.py b/ot/optim.py
index abe9e6a..0359343 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -178,9 +178,9 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
numItermaxEmd : int, optional
Max number of iterations for emd
stopThr : float, optional
- Stop threshol on the relative variation (>0)
+ Stop threshold on the relative variation (>0)
stopThr2 : float, optional
- Stop threshol on the absolute variation (>0)
+ Stop threshold on the absolute variation (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -249,6 +249,8 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
# line search
alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs)
+ if alpha is None:
+ alpha = 0.0
G = G + alpha * deltaG
@@ -320,9 +322,9 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
numInnerItermax : int, optional
Max number of iterations of Sinkhorn
stopThr : float, optional
- Stop threshol on the relative variation (>0)
+ Stop threshold on the relative variation (>0)
stopThr2 : float, optional
- Stop threshol on the absolute variation (>0)
+ Stop threshold on the absolute variation (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
diff --git a/test/test_da.py b/test/test_da.py
index 44bb2e9..9f2bb50 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -565,6 +565,14 @@ def test_mapping_transport_class():
otda.fit(Xs=Xs, Xt=Xt)
assert len(otda.log_.keys()) != 0
+ # check that it does not crash when derphi is very close to 0
+ np.random.seed(39)
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+ otda = ot.da.MappingTransport(kernel="gaussian", bias=False)
+ otda.fit(Xs=Xs, Xt=Xt)
+ np.random.seed(None)
+
def test_linear_mapping():
ns = 150
diff --git a/test/test_optim.py b/test/test_optim.py
index fd194c2..94995d5 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -114,3 +114,28 @@ def test_line_search_armijo():
# Should not throw an exception and return None for alpha
alpha, _, _ = ot.optim.line_search_armijo(lambda x: 1, xk, pk, gfk, old_fval)
assert alpha is None
+
+ # check line search armijo
+ def f(x):
+ return np.sum((x - 5.0) ** 2)
+
+ def grad(x):
+ return 2 * (x - 5.0)
+
+ xk = np.array([[[-5.0, -5.0]]])
+ pk = np.array([[[100.0, 100.0]]])
+ gfk = grad(xk)
+ old_fval = f(xk)
+
+ # chech the case where the optimum is on the direction
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval)
+ np.testing.assert_allclose(alpha, 0.1)
+
+ # check the case where the direction is not far enough
+ pk = np.array([[[3.0, 3.0]]])
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval, alpha0=1.0)
+ np.testing.assert_allclose(alpha, 1.0)
+
+ # check the case where the checking the wrong direction
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, -pk, gfk, old_fval)
+ assert alpha <= 0