Diffstat (limited to 'ot/optim.py')
-rw-r--r--  ot/optim.py  496
1 file changed, 288 insertions, 208 deletions
diff --git a/ot/optim.py b/ot/optim.py
index 5a1d605..58e5596 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -1,11 +1,11 @@
# -*- coding: utf-8 -*-
"""
-Generic solvers for regularized OT
+Generic solvers for regularized OT or its semi-relaxed version.
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
-#
+# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
# License: MIT License
import numpy as np
@@ -27,7 +27,7 @@ with warnings.catch_warnings():
def line_search_armijo(
f, xk, pk, gfk, old_fval, args=(), c1=1e-4,
- alpha0=0.99, alpha_min=None, alpha_max=None
+ alpha0=0.99, alpha_min=None, alpha_max=None, nx=None, **kwargs
):
r"""
Armijo linesearch function that works with matrices
@@ -35,6 +35,9 @@ def line_search_armijo(
Find an approximate minimum of :math:`f(x_k + \alpha \cdot p_k)` that satisfies the
armijo conditions.
+ .. note:: If the loss function f returns a float (resp. a 1d array) then
+ the returned alpha and fa are floats (resp. 1d arrays).
+
Parameters
----------
f : callable
@@ -45,7 +48,7 @@ def line_search_armijo(
descent direction
gfk : array-like
gradient of `f` at :math:`x_k`
- old_fval : float
+ old_fval : float or 1d array
loss value at :math:`x_k`
args : tuple, optional
arguments given to `f`
@@ -57,138 +60,97 @@ def line_search_armijo(
minimum value for alpha
alpha_max : float, optional
maximum value for alpha
-
+ nx : backend, optional
+ If left to its default value None, the backend is inferred from the inputs.
Returns
-------
- alpha : float
+ alpha : float or 1d array
step that satisfies the armijo conditions
fc : int
number of function calls
- fa : float
+ fa : float or 1d array
loss value at step alpha
"""
-
- xk, pk, gfk = list_to_array(xk, pk, gfk)
- nx = get_backend(xk, pk)
+ if nx is None:
+ xk, pk, gfk = list_to_array(xk, pk, gfk)
+ xk0, pk0 = xk, pk
+ nx = get_backend(xk0, pk0)
+ else:
+ xk0, pk0 = xk, pk
if len(xk.shape) == 0:
xk = nx.reshape(xk, (-1,))
+ xk = nx.to_numpy(xk)
+ pk = nx.to_numpy(pk)
+ gfk = nx.to_numpy(gfk)
+
fc = [0]
def phi(alpha1):
+ # the callable f operates on the nx backend
fc[0] += 1
- return f(xk + alpha1 * pk, *args)
+ alpha10 = nx.from_numpy(alpha1)
+ fval = f(xk0 + alpha10 * pk0, *args)
+ if type(fval) is float:
+ # plain floats are returned as is: nx.to_numpy may look for .cpu or .gpu attributes
+ return fval
+ else:
+ return nx.to_numpy(fval)
if old_fval is None:
phi0 = phi(0.)
- else:
+ elif type(old_fval) is float:
+ # plain floats are used as is: nx.to_numpy may look for .cpu or .gpu attributes
phi0 = old_fval
+ else:
+ phi0 = nx.to_numpy(old_fval)
- derphi0 = nx.sum(pk * gfk) # Quickfix for matrices
+ derphi0 = np.sum(pk * gfk) # Quickfix for matrices
alpha, phi1 = scalar_search_armijo(
phi, phi0, derphi0, c1=c1, alpha0=alpha0)
if alpha is None:
- return 0., fc[0], phi0
+ return 0., fc[0], nx.from_numpy(phi0, type_as=xk0)
else:
if alpha_min is not None or alpha_max is not None:
alpha = np.clip(alpha, alpha_min, alpha_max)
- return float(alpha), fc[0], phi1
+ return nx.from_numpy(alpha, type_as=xk0), fc[0], nx.from_numpy(phi1, type_as=xk0)
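
A minimal sketch of exercising the reworked line search on its own, assuming NumPy inputs so that the backend is inferred when nx is left to None; the quadratic objective and values below are illustrative, not part of the patch:

import numpy as np
from ot.optim import line_search_armijo

def f(x):
    # toy objective ||x - 1||^2; returning a plain float exercises the float path above
    return float(np.sum((x - 1.0) ** 2))

xk = np.zeros(3)         # current iterate
gfk = 2.0 * (xk - 1.0)   # gradient of f at xk
pk = -gfk                # descent direction

alpha, fc, fa = line_search_armijo(f, xk, pk, gfk, f(xk),
                                   alpha_min=0., alpha_max=1.)
print(alpha, fc, fa)     # accepted step, number of calls to f, loss at that step
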
-def solve_linesearch(
- cost, G, deltaG, Mi, f_val, armijo=True, C1=None, C2=None,
- reg=None, Gc=None, constC=None, M=None, alpha_min=None, alpha_max=None
-):
- """
- Solve the linesearch in the FW iterations
-
- Parameters
- ----------
- cost : method
- Cost in the FW for the linesearch
- G : array-like, shape(ns,nt)
- The transport map at a given iteration of the FW
- deltaG : array-like (ns,nt)
- Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration
- Mi : array-like (ns,nt)
- Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost
- f_val : float
- Value of the cost at `G`
- armijo : bool, optional
- If True the steps of the line-search is found via an armijo research. Else closed form is used.
- If there is convergence issues use False.
- C1 : array-like (ns,ns), optional
- Structure matrix in the source domain. Only used and necessary when armijo=False
- C2 : array-like (nt,nt), optional
- Structure matrix in the target domain. Only used and necessary when armijo=False
- reg : float, optional
- Regularization parameter. Only used and necessary when armijo=False
- Gc : array-like (ns,nt)
- Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False
- constC : array-like (ns,nt)
- Constant for the gromov cost. See :ref:`[24] <references-solve-linesearch>`. Only used and necessary when armijo=False
- M : array-like (ns,nt), optional
- Cost matrix between the features. Only used and necessary when armijo=False
- alpha_min : float, optional
- Minimum value for alpha
- alpha_max : float, optional
- Maximum value for alpha
-
- Returns
- -------
- alpha : float
- The optimal step size of the FW
- fc : int
- nb of function call. Useless here
- f_val : float
- The value of the cost for the next iteration
+def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_search, G0=None,
+ numItermax=200, stopThr=1e-9,
+ stopThr2=1e-9, verbose=False, log=False, **kwargs):
+ r"""
+ Solve the general regularized OT problem or its semi-relaxed version with
+ conditional gradient or generalized conditional gradient depending on the
+ provided linear program solver.
+ When used as a conditional gradient solver, the function solves the following optimization problem:
- .. _references-solve-linesearch:
- References
- ----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
- "Optimal Transport for structured data with application on graphs"
- International Conference on Machine Learning (ICML). 2019.
- """
- if armijo:
- alpha, fc, f_val = line_search_armijo(
- cost, G, deltaG, Mi, f_val, alpha_min=alpha_min, alpha_max=alpha_max
- )
- else: # requires symetric matrices
- G, deltaG, C1, C2, constC, M = list_to_array(G, deltaG, C1, C2, constC, M)
- if isinstance(M, int) or isinstance(M, float):
- nx = get_backend(G, deltaG, C1, C2, constC)
- else:
- nx = get_backend(G, deltaG, C1, C2, constC, M)
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg_1} \cdot f(\gamma)
- dot = nx.dot(nx.dot(C1, deltaG), C2)
- a = -2 * reg * nx.sum(dot * deltaG)
- b = nx.sum((M + reg * constC) * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2) * deltaG))
- c = cost(G)
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- alpha = solve_1d_linesearch_quad(a, b, c)
- if alpha_min is not None or alpha_max is not None:
- alpha = np.clip(alpha, alpha_min, alpha_max)
- fc = None
- f_val = cost(G + alpha * deltaG)
+ \gamma^T \mathbf{1} &= \mathbf{b} (optional constraint)
- return alpha, fc, f_val
+ \gamma &\geq 0
+ where :
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`f` is the regularization term (and `df` is its gradient)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
-def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
- stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
- r"""
- Solve the general regularized OT problem with conditional gradient
+ The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>`
- The function solves the following optimization problem:
+ When used as a generalized conditional gradient solver, the function solves the following optimization problem:
.. math::
\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
- \mathrm{reg} \cdot f(\gamma)
+ \mathrm{reg_1}\cdot f(\gamma) + \mathrm{reg_2}\cdot\Omega(\gamma)
s.t. \ \gamma \mathbf{1} &= \mathbf{a}
@@ -197,29 +159,39 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
\gamma &\geq 0
where :
- - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- - :math:`f` is the regularization term (and `df` is its gradient)
- - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
-
- The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>`
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ The algorithm used for solving the problem is the generalized conditional gradient as discussed in :ref:`[5, 7] <references-gcg>`
Parameters
----------
a : array-like, shape (ns,)
samples weights in the source domain
b : array-like, shape (nt,)
- samples in the target domain
+ samples weights in the target domain
M : array-like, shape (ns, nt)
loss matrix
- reg : float
+ f : function
+ Regularization function taking a transportation matrix as argument
+ df: function
+ Gradient of the regularization function taking a transportation matrix as argument
+ reg1 : float
Regularization term >0
+ reg2 : float,
+ Entropic regularization term >0. Ignored if set to None.
+ lp_solver: function,
+ linear program solver for direction finding of the (generalized) conditional gradient.
+ If set to emd, the general regularized OT problem is solved with cg.
+ If set to lp_semi_relaxed_OT, the general regularized semi-relaxed OT problem is solved with cg.
+ If set to sinkhorn, the general regularized OT problem is solved with generalized cg.
+ line_search: function,
+ Function to find the optimal step. Currently used instances are:
+ line_search_armijo (generic solver), solve_gromov_linesearch for the (F)GW problem,
+ solve_semirelaxed_gromov_linesearch for the sr(F)GW problem, and gcg_linesearch for the generalized cg.
G0 : array-like, shape (ns,nt), optional
initial guess (default is indep joint density)
numItermax : int, optional
Max number of iterations
- numItermaxEmd : int, optional
- Max number of iterations for emd
stopThr : float, optional
Stop threshold on the relative variation (>0)
stopThr2 : float, optional
@@ -240,16 +212,20 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
.. _references-cg:
+ .. _references-gcg:
References
----------
.. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
+ .. [5] N. Courty, R. Flamary, D. Tuia, A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," IEEE Transactions on Pattern Analysis and Machine Intelligence, vol. PP, no. 99, pp. 1-1
+
+ .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
+
See Also
--------
ot.lp.emd : Unregularized optimal transport
ot.bregman.sinkhorn : Entropic regularized optimal transport
-
"""
a, b, M, G0 = list_to_array(a, b, M, G0)
if isinstance(M, int) or isinstance(M, float):
@@ -265,42 +241,45 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
if G0 is None:
G = nx.outer(a, b)
else:
- G = G0
-
- def cost(G):
- return nx.sum(M * G) + reg * f(G)
+ # copy G0 so that it is not modified in place.
+ G = nx.copy(G0)
- f_val = cost(G)
+ if reg2 is None:
+ def cost(G):
+ return nx.sum(M * G) + reg1 * f(G)
+ else:
+ def cost(G):
+ return nx.sum(M * G) + reg1 * f(G) + reg2 * nx.sum(G * nx.log(G))
+ cost_G = cost(G)
if log:
- log['loss'].append(f_val)
+ log['loss'].append(cost_G)
it = 0
if verbose:
print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
- print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0))
+ print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, cost_G, 0, 0))
while loop:
it += 1
- old_fval = f_val
-
+ old_cost_G = cost_G
# problem linearization
- Mi = M + reg * df(G)
+ Mi = M + reg1 * df(G)
+
+ if reg2 is not None:
+ Mi = Mi + reg2 * (1 + nx.log(G))
# set M positive
- Mi += nx.min(Mi)
+ Mi = Mi + nx.min(Mi)
# solve linear program
- Gc, logemd = emd(a, b, Mi, numItermax=numItermaxEmd, log=True)
+ Gc, innerlog_ = lp_solver(a, b, Mi, **kwargs)
+ # line search
deltaG = Gc - G
- # line search
- alpha, fc, f_val = solve_linesearch(
- cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc,
- alpha_min=0., alpha_max=1., **kwargs
- )
+ alpha, fc, cost_G = line_search(cost, G, deltaG, Mi, cost_G, **kwargs)
G = G + alpha * deltaG
@@ -308,29 +287,197 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
if it >= numItermax:
loop = 0
- abs_delta_fval = abs(f_val - old_fval)
- relative_delta_fval = abs_delta_fval / abs(f_val)
- if relative_delta_fval < stopThr or abs_delta_fval < stopThr2:
+ abs_delta_cost_G = abs(cost_G - old_cost_G)
+ relative_delta_cost_G = abs_delta_cost_G / abs(cost_G)
+ if relative_delta_cost_G < stopThr or abs_delta_cost_G < stopThr2:
loop = 0
if log:
- log['loss'].append(f_val)
+ log['loss'].append(cost_G)
if verbose:
if it % 20 == 0:
print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
- print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
+ print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, cost_G, relative_delta_cost_G, abs_delta_cost_G))
if log:
- log.update(logemd)
+ log.update(innerlog_)
return G, log
else:
return G
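
To make the new plug-in contract concrete, here is a sketch of a direct call to generic_conditional_gradient with an exact-OT direction finder and the armijo line search. The data, the quadratic regularizer and the reg1 value are illustrative; the only requirements are that lp_solver returns a (plan, inner_log) pair and that line_search returns (alpha, fc, cost):

import numpy as np
import ot
from ot.optim import generic_conditional_gradient, line_search_armijo

rng = np.random.RandomState(0)
a, b = ot.unif(5), ot.unif(6)
M = ot.dist(rng.randn(5, 2), rng.randn(6, 2))

def f(G):
    # quadratic regularizer and its gradient
    return 0.5 * np.sum(G ** 2)

def df(G):
    return G

def lp_solver(a, b, Mi, **kwargs):
    # direction finding on the linearized cost; log=True yields the
    # (plan, inner_log) pair expected by the generic solver
    return ot.emd(a, b, Mi, log=True)

# reg2=None selects the plain conditional gradient path (no entropic term);
# extra kwargs (here the step bounds) are forwarded to both callbacks
G = generic_conditional_gradient(a, b, M, f, df, reg1=1e-1, reg2=None,
                                 lp_solver=lp_solver,
                                 line_search=line_search_armijo,
                                 alpha_min=0., alpha_max=1.)
print(np.sum(G * M))   # linear part of the final cost
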
+def cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo,
+ numItermax=200, numItermaxEmd=100000, stopThr=1e-9, stopThr2=1e-9,
+ verbose=False, log=False, **kwargs):
+ r"""
+ Solve the general regularized OT problem with conditional gradient
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot f(\gamma)
+
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
+
+ \gamma^T \mathbf{1} &= \mathbf{b}
+
+ \gamma &\geq 0
+ where :
+
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`f` is the regularization term (and `df` is its gradient)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
+
+ The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>`
+
+
+ Parameters
+ ----------
+ a : array-like, shape (ns,)
+ samples weights in the source domain
+ b : array-like, shape (nt,)
+ samples weights in the target domain
+ M : array-like, shape (ns, nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ G0 : array-like, shape (ns,nt), optional
+ initial guess (default is indep joint density)
+ line_search: function,
+ Function to find the optimal step.
+ Default is line_search_armijo.
+ numItermax : int, optional
+ Max number of iterations
+ numItermaxEmd : int, optional
+ Max number of iterations for emd
+ stopThr : float, optional
+ Stop threshold on the relative variation (>0)
+ stopThr2 : float, optional
+ Stop threshold on the absolute variation (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ **kwargs : dict
+ Parameters for linesearch
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary, returned only if log==True in parameters
+
+
+ .. _references-cg:
+ References
+ ----------
+
+ .. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized optimal transport
+ ot.bregman.sinkhorn : Entropic regularized optimal transport
+
+ """
+
+ def lp_solver(a, b, M, **kwargs):
+ return emd(a, b, M, numItermaxEmd, log=True)
+
+ return generic_conditional_gradient(a, b, M, f, df, reg, None, lp_solver, line_search, G0=G0,
+ numItermax=numItermax, stopThr=stopThr,
+ stopThr2=stopThr2, verbose=verbose, log=log, **kwargs)
+
+
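
Existing calls through the public cg wrapper keep their previous form; the only new argument is line_search, which defaults to line_search_armijo. A short sketch with illustrative data and the usual squared-Frobenius regularizer:

import numpy as np
import ot

a, b = ot.unif(4), ot.unif(5)
M = ot.dist(np.arange(4, dtype=float).reshape(-1, 1),
            np.arange(5, dtype=float).reshape(-1, 1))

def f(G):
    return 0.5 * np.sum(G ** 2)

def df(G):
    return G

G, log = ot.optim.cg(a, b, M, reg=1e-1, f=f, df=df, log=True)
print(log['loss'][-1])   # regularized cost at the last iteration
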
+def semirelaxed_cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo,
+ numItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
+ r"""
+ Solve the general regularized and semi-relaxed OT problem with conditional gradient
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot f(\gamma)
+
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
+
+ \gamma &\geq 0
+ where :
+
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`f` is the regularization term (and `df` is its gradient)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
+
+ The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[48] <references-semirelaxed-cg>`
+
+
+ Parameters
+ ----------
+ a : array-like, shape (ns,)
+ samples weights in the source domain
+ b : array-like, shape (nt,)
+ currently estimated samples weights in the target domain
+ M : array-like, shape (ns, nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ G0 : array-like, shape (ns,nt), optional
+ initial guess (default is indep joint density)
+ line_search: function,
+ Function to find the optimal step.
+ Default is the armijo line-search.
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on the relative variation (>0)
+ stopThr2 : float, optional
+ Stop threshold on the absolute variation (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ **kwargs : dict
+ Parameters for linesearch
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary, returned only if log==True in parameters
+
+
+ .. _references-semirelaxed-cg:
+ References
+ ----------
+
+ .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+ "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+ International Conference on Learning Representations (ICLR), 2021.
+
+ """
+
+ nx = get_backend(a, b)
+
+ def lp_solver(a, b, Mi, **kwargs):
+ # binary mask selecting the row-wise minima of Mi
+ Gc = nx.ones(1, type_as=a) * (Mi == nx.reshape(nx.min(Mi, axis=1), (-1, 1)))
+ Gc *= nx.reshape((a / nx.sum(Gc, axis=1)), (-1, 1))
+ # return by default an empty inner_log
+ return Gc, {}
+
+ return generic_conditional_gradient(a, b, M, f, df, reg, None, lp_solver, line_search, G0=G0,
+ numItermax=numItermax, stopThr=stopThr,
+ stopThr2=stopThr2, verbose=verbose, log=log, **kwargs)
+
+
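
The lp_solver defined in semirelaxed_cg exploits the closed form of the semi-relaxed linearized problem: with only the row-marginal constraint, each source mass a_i goes entirely to the cheapest column(s) of Mi, ties sharing the mass equally. A NumPy sketch of that step, with illustrative values:

import numpy as np

a = np.array([0.5, 0.3, 0.2])
Mi = np.array([[1.0, 0.2, 3.0],
               [0.5, 2.0, 0.1],
               [4.0, 4.0, 4.0]])   # last row: three-way tie

mask = (Mi == Mi.min(axis=1, keepdims=True)).astype(float)
Gc = mask * (a / mask.sum(axis=1))[:, None]

print(Gc.sum(axis=1))    # row marginals recover a
print((Gc * Mi).sum())   # equals sum_i a_i * min_j Mi[i, j]
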
def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
- numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False):
+ numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
r"""
Solve the general regularized OT problem with the generalized conditional gradient
@@ -403,81 +550,18 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
ot.optim.cg : conditional gradient
"""
- a, b, M, G0 = list_to_array(a, b, M, G0)
- nx = get_backend(a, b, M)
-
- loop = 1
-
- if log:
- log = {'loss': []}
-
- if G0 is None:
- G = nx.outer(a, b)
- else:
- G = G0
-
- def cost(G):
- return nx.sum(M * G) + reg1 * nx.sum(G * nx.log(G)) + reg2 * f(G)
-
- f_val = cost(G)
- if log:
- log['loss'].append(f_val)
-
- it = 0
-
- if verbose:
- print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
- 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
- print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0))
-
- while loop:
-
- it += 1
- old_fval = f_val
-
- # problem linearization
- Mi = M + reg2 * df(G)
-
- # solve linear program with Sinkhorn
- # Gc = sinkhorn_stabilized(a,b, Mi, reg1, numItermax = numInnerItermax)
- Gc = sinkhorn(a, b, Mi, reg1, numItermax=numInnerItermax)
- deltaG = Gc - G
-
- # line search
- dcost = Mi + reg1 * (1 + nx.log(G)) # ??
- alpha, fc, f_val = line_search_armijo(
- cost, G, deltaG, dcost, f_val, alpha_min=0., alpha_max=1.
- )
+ def lp_solver(a, b, Mi, **kwargs):
+ return sinkhorn(a, b, Mi, reg1, numItermax=numInnerItermax, log=True, **kwargs)
- G = G + alpha * deltaG
+ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
+ return line_search_armijo(cost, G, deltaG, Mi, cost_G, **kwargs)
- # test convergence
- if it >= numItermax:
- loop = 0
+ return generic_conditional_gradient(a, b, M, f, df, reg2, reg1, lp_solver, line_search, G0=G0,
+ numItermax=numItermax, stopThr=stopThr, stopThr2=stopThr2, verbose=verbose, log=log, **kwargs)
- abs_delta_fval = abs(f_val - old_fval)
- relative_delta_fval = abs_delta_fval / abs(f_val)
- if relative_delta_fval < stopThr or abs_delta_fval < stopThr2:
- loop = 0
-
- if log:
- log['loss'].append(f_val)
-
- if verbose:
- if it % 20 == 0:
- print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
- 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
- print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
-
- if log:
- return G, log
- else:
- return G
-
-
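
The refactored gcg keeps its public signature; a sketch of a call mixing the entropic term reg1 (handled by the inner Sinkhorn solver) with a quadratic regularizer weighted by reg2, on illustrative data:

import numpy as np
import ot

rng = np.random.RandomState(42)
a, b = ot.unif(6), ot.unif(7)
M = ot.dist(rng.randn(6, 2), rng.randn(7, 2))

def f(G):
    return 0.5 * np.sum(G ** 2)

def df(G):
    return G

G = ot.optim.gcg(a, b, M, reg1=0.5, reg2=1e-1, f=f, df=df, numItermax=10)
print(G.sum())   # the plan remains a joint distribution with total mass 1
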
-def solve_1d_linesearch_quad(a, b, c):
+def solve_1d_linesearch_quad(a, b):
r"""
For any convex or non-convex 1d quadratic function `f`, solve the following problem:
@@ -487,7 +571,7 @@ def solve_1d_linesearch_quad(a, b, c):
Parameters
----------
- a,b,c : float
+ a,b : float or tensors (1,)
The coefficients of the quadratic function
Returns
@@ -495,15 +579,11 @@ def solve_1d_linesearch_quad(a, b, c):
x : float
The optimal value which leads to the minimal cost
"""
- f0 = c
- df0 = b
- f1 = a + f0 + df0
-
if a > 0: # convex
- minimum = min(1, max(0, np.divide(-b, 2.0 * a)))
+ minimum = min(1., max(0., -b / (2.0 * a)))
return minimum
else: # non convex
- if f0 > f1:
- return 1
+ if a + b < 0:
+ return 1.
else:
- return 0
+ return 0.
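
The third coefficient was dropped because the minimizer of g(t) = a t^2 + b t + c over [0, 1] never depends on the constant c: for a > 0 it is clip(-b / (2a), 0, 1), and otherwise it is whichever endpoint gives the smaller value, i.e. 1 when a + b < 0 and 0 otherwise. A brute-force check of the new two-argument form (the coefficients below are arbitrary):

import numpy as np
from ot.optim import solve_1d_linesearch_quad

for a_, b_ in [(2.0, -3.0), (1.0, 1.0), (-1.0, 0.5), (-1.0, 3.0)]:
    t_star = solve_1d_linesearch_quad(a_, b_)
    ts = np.linspace(0.0, 1.0, 1001)
    t_brute = ts[np.argmin(a_ * ts ** 2 + b_ * ts)]
    assert abs(t_star - t_brute) < 1e-2, (a_, b_, t_star, t_brute)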