summaryrefslogtreecommitdiff
path: root/ot/optim.py
diff options
context:
space:
mode:
authorCédric Vincent-Cuaz <cedvincentcuaz@gmail.com>2023-03-16 08:05:54 +0100
committerGitHub <noreply@github.com>2023-03-16 08:05:54 +0100
commit583501652517c4f1dbd8572e9f942551a9e54a1f (patch)
treefadb96f888924b2d1bef01b78486e97a88ebcd42 /ot/optim.py
parent8f56effe7320991ebdc6457a2cf1d3b6648a09d1 (diff)
[MRG] fix bugs of gw_entropic and armijo to run on gpu (#446)
* maj gw/ srgw/ generic cg solver * correct pep8 on current state * fix bug previous tests * fix pep8 * fix bug srGW constC in loss and gradient * fix doc html * fix doc html * start updating test_optim.py * update tests gromov and optim - plus fix gromov dependencies * add symmetry feature to entropic gw * add symmetry feature to entropic gw * add exemple for sr(F)GW matchings * small stuff * remove (reg,M) from line-search/ complete srgw tests with backend * remove backend repetitions / rename fG to costG/ fix innerlog to True * fix pep8 * take comments into account / new nx parameters still to test * factor (f)gw2 + test new backend parameters in ot.gromov + harmonize stopping criterions * split gromov.py in ot/gromov/ + update test_gromov with helper_backend functions * manual documentaion gromov * remove circular autosummary * trying stuff * debug documentation * alphabetic ordering of module * merge into branch * add note in entropic gw solvers * fix exemples/gromov doc * add fixed issue to releases.md * fix bugs of gw_entropic and armijo to run on gpu * add pr to releases.md * fix pep8 * fix call to backend in line_search_armijo * correct docstring generic_conditional_gradient --------- Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot/optim.py')
-rw-r--r--ot/optim.py38
1 files changed, 29 insertions, 9 deletions
diff --git a/ot/optim.py b/ot/optim.py
index 201f898..58e5596 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -35,6 +35,9 @@ def line_search_armijo(
Find an approximate minimum of :math:`f(x_k + \alpha \cdot p_k)` that satisfies the
armijo conditions.
+ .. note:: If the loss function f returns a float (resp. a 1d array) then
+ the returned alpha and fa are float (resp. 1d arrays).
+
Parameters
----------
f : callable
@@ -45,7 +48,7 @@ def line_search_armijo(
descent direction
gfk : array-like
gradient of `f` at :math:`x_k`
- old_fval : float
+ old_fval : float or 1d array
loss value at :math:`x_k`
args : tuple, optional
arguments given to `f`
@@ -61,42 +64,59 @@ def line_search_armijo(
If let to its default value None, a backend test will be conducted.
Returns
-------
- alpha : float
+ alpha : float or 1d array
step that satisfy armijo conditions
fc : int
nb of function call
- fa : float
+ fa : float or 1d array
loss value at step alpha
"""
if nx is None:
xk, pk, gfk = list_to_array(xk, pk, gfk)
- nx = get_backend(xk, pk)
+ xk0, pk0 = xk, pk
+ nx = get_backend(xk0, pk0)
+ else:
+ xk0, pk0 = xk, pk
if len(xk.shape) == 0:
xk = nx.reshape(xk, (-1,))
+ xk = nx.to_numpy(xk)
+ pk = nx.to_numpy(pk)
+ gfk = nx.to_numpy(gfk)
+
fc = [0]
def phi(alpha1):
+ # The callable function operates on nx backend
fc[0] += 1
- return f(xk + alpha1 * pk, *args)
+ alpha10 = nx.from_numpy(alpha1)
+ fval = f(xk0 + alpha10 * pk0, *args)
+ if type(fval) is float:
+ # prevent bug from nx.to_numpy that can look for .cpu or .gpu
+ return fval
+ else:
+ return nx.to_numpy(fval)
if old_fval is None:
phi0 = phi(0.)
- else:
+ elif type(old_fval) is float:
+ # prevent bug from nx.to_numpy that can look for .cpu or .gpu
phi0 = old_fval
+ else:
+ phi0 = nx.to_numpy(old_fval)
- derphi0 = nx.sum(pk * gfk) # Quickfix for matrices
+ derphi0 = np.sum(pk * gfk) # Quickfix for matrices
alpha, phi1 = scalar_search_armijo(
phi, phi0, derphi0, c1=c1, alpha0=alpha0)
if alpha is None:
- return 0., fc[0], phi0
+ return 0., fc[0], nx.from_numpy(phi0, type_as=xk0)
else:
if alpha_min is not None or alpha_max is not None:
alpha = np.clip(alpha, alpha_min, alpha_max)
- return float(alpha), fc[0], phi1
+ return nx.from_numpy(alpha, type_as=xk0), fc[0], nx.from_numpy(phi1, type_as=xk0)
def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_search, G0=None,