summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> 2021-11-17 18:59:59 +0100
committer GitHub <noreply@github.com> 2021-11-17 18:59:59 +0100
commit 8b9e641645f7825255afacfd141bfbf52ba2857e (patch)
tree 8cdecdd9ec1dba73ee9a1c43edd1ff06fc06465b
parent e235b08c2ac86f75b6c1b8e96e503305aa0449e1 (diff)
[MRG] SinkhornL1L2 bug solve (#313)
* Now limiting alpha to a minimum value as well as a max value
* Docs
* typo
-rw-r--r-- ot/optim.py 39
1 files changed, 30 insertions, 9 deletions
diff --git a/ot/optim.py b/ot/optim.py
index cacec53..9b8a8f7 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -18,8 +18,10 @@ from .backend import get_backend
# The corresponding scipy function does not work for matrices
-def line_search_armijo(f, xk, pk, gfk, old_fval,
- args=(), c1=1e-4, alpha0=0.99):
+def line_search_armijo(
+ f, xk, pk, gfk, old_fval, args=(), c1=1e-4,
+ alpha0=0.99, alpha_min=None, alpha_max=None
+):
r"""
Armijo linesearch function that works with matrices
@@ -44,6 +46,10 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
:math:`c_1` const in armijo rule (>0)
alpha0 : float, optional
initial step (>0)
+ alpha_min : float, optional
+ minimum value for alpha
+ alpha_max : float, optional
+ maximum value for alpha
Returns
-------
@@ -80,13 +86,15 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
if alpha is None:
return 0., fc[0], phi0
else:
- # scalar_search_armijo can return alpha > 1
- alpha = min(1, alpha)
+ if alpha_min is not None or alpha_max is not None:
+ alpha = np.clip(alpha, alpha_min, alpha_max)
return alpha, fc[0], phi1
-def solve_linesearch(cost, G, deltaG, Mi, f_val,
- armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
+def solve_linesearch(
+ cost, G, deltaG, Mi, f_val, armijo=True, C1=None, C2=None,
+ reg=None, Gc=None, constC=None, M=None, alpha_min=None, alpha_max=None
+):
"""
Solve the linesearch in the FW iterations
@@ -117,6 +125,10 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
Constant for the gromov cost. See :ref:`[24] <references-solve-linesearch>`. Only used and necessary when armijo=False
M : array-like (ns,nt), optional
Cost matrix between the features. Only used and necessary when armijo=False
+ alpha_min : float, optional
+ Minimum value for alpha
+ alpha_max : float, optional
+ Maximum value for alpha
Returns
-------
@@ -136,7 +148,9 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
International Conference on Machine Learning (ICML). 2019.
"""
if armijo:
- alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val)
+ alpha, fc, f_val = line_search_armijo(
+ cost, G, deltaG, Mi, f_val, alpha_min=alpha_min, alpha_max=alpha_max
+ )
else: # requires symetric matrices
G, deltaG, C1, C2, constC, M = list_to_array(G, deltaG, C1, C2, constC, M)
if isinstance(M, int) or isinstance(M, float):
@@ -150,6 +164,8 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
c = cost(G)
alpha = solve_1d_linesearch_quad(a, b, c)
+ if alpha_min is not None or alpha_max is not None:
+ alpha = np.clip(alpha, alpha_min, alpha_max)
fc = None
f_val = cost(G + alpha * deltaG)
@@ -274,7 +290,10 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
deltaG = Gc - G
# line search
- alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs)
+ alpha, fc, f_val = solve_linesearch(
+ cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc,
+ alpha_min=0., alpha_max=1., **kwargs
+ )
G = G + alpha * deltaG
@@ -420,7 +439,9 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
# line search
dcost = Mi + reg1 * (1 + nx.log(G)) # ??
- alpha, fc, f_val = line_search_armijo(cost, G, deltaG, dcost, f_val)
+ alpha, fc, f_val = line_search_armijo(
+ cost, G, deltaG, dcost, f_val, alpha_min=0., alpha_max=1.
+ )
G = G + alpha * deltaG