author     Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>   2023-03-16 08:05:54 +0100
committer  GitHub <noreply@github.com>                      2023-03-16 08:05:54 +0100
commit     583501652517c4f1dbd8572e9f942551a9e54a1f (patch)
tree       fadb96f888924b2d1bef01b78486e97a88ebcd42 /ot/optim.py
parent     8f56effe7320991ebdc6457a2cf1d3b6648a09d1 (diff)
[MRG] fix bugs of gw_entropic and armijo to run on gpu (#446)
* update gw / srgw / generic cg solver
* correct pep8 on current state
* fix bug previous tests
* fix pep8
* fix bug srGW constC in loss and gradient
* fix doc html
* fix doc html
* start updating test_optim.py
* update tests gromov and optim - plus fix gromov dependencies
* add symmetry feature to entropic gw
* add symmetry feature to entropic gw
* add example for sr(F)GW matchings
* small stuff
* remove (reg,M) from line-search/ complete srgw tests with backend
* remove backend repetitions / rename fG to costG/ fix innerlog to True
* fix pep8
* take comments into account / new nx parameters still to test
* factor (f)gw2 + test new backend parameters in ot.gromov + harmonize stopping criteria
* split gromov.py in ot/gromov/ + update test_gromov with helper_backend functions
* manual documentation gromov
* remove circular autosummary
* trying stuff
* debug documentation
* alphabetic ordering of module
* merge into branch
* add note in entropic gw solvers
* fix examples/gromov doc
* add fixed issue to releases.md
* fix bugs of gw_entropic and armijo to run on gpu
* add pr to releases.md
* fix pep8
* fix call to backend in line_search_armijo
* correct docstring generic_conditional_gradient
---------
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot/optim.py')
-rw-r--r--  ot/optim.py  38
1 file changed, 29 insertions(+), 9 deletions(-)
diff --git a/ot/optim.py b/ot/optim.py
index 201f898..58e5596 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -35,6 +35,9 @@ def line_search_armijo(
     Find an approximate minimum of :math:`f(x_k + \alpha \cdot p_k)` that satisfies the
     armijo conditions.
 
+    .. note:: If the loss function f returns a float (resp. a 1d array) then
+        the returned alpha and fa are float (resp. 1d arrays).
+
     Parameters
     ----------
     f : callable
@@ -45,7 +48,7 @@
         descent direction
     gfk : array-like
         gradient of `f` at :math:`x_k`
-    old_fval : float
+    old_fval : float or 1d array
         loss value at :math:`x_k`
     args : tuple, optional
         arguments given to `f`
@@ -61,42 +64,59 @@
         If let to its default value None, a backend test will be conducted.
     Returns
     -------
-    alpha : float
+    alpha : float or 1d array
         step that satisfy armijo conditions
     fc : int
         nb of function call
-    fa : float
+    fa : float or 1d array
         loss value at step alpha
 
     """
     if nx is None:
         xk, pk, gfk = list_to_array(xk, pk, gfk)
-        nx = get_backend(xk, pk)
+        xk0, pk0 = xk, pk
+        nx = get_backend(xk0, pk0)
+    else:
+        xk0, pk0 = xk, pk
 
     if len(xk.shape) == 0:
         xk = nx.reshape(xk, (-1,))
 
+    xk = nx.to_numpy(xk)
+    pk = nx.to_numpy(pk)
+    gfk = nx.to_numpy(gfk)
+
     fc = [0]
 
     def phi(alpha1):
+        # The callable function operates on nx backend
         fc[0] += 1
-        return f(xk + alpha1 * pk, *args)
+        alpha10 = nx.from_numpy(alpha1)
+        fval = f(xk0 + alpha10 * pk0, *args)
+        if type(fval) is float:
+            # prevent bug from nx.to_numpy that can look for .cpu or .gpu
+            return fval
+        else:
+            return nx.to_numpy(fval)
 
     if old_fval is None:
         phi0 = phi(0.)
-    else:
+    elif type(old_fval) is float:
+        # prevent bug from nx.to_numpy that can look for .cpu or .gpu
         phi0 = old_fval
+    else:
+        phi0 = nx.to_numpy(old_fval)
 
-    derphi0 = nx.sum(pk * gfk)  # Quickfix for matrices
+    derphi0 = np.sum(pk * gfk)  # Quickfix for matrices
     alpha, phi1 = scalar_search_armijo(
         phi, phi0, derphi0, c1=c1, alpha0=alpha0)
 
     if alpha is None:
-        return 0., fc[0], phi0
+        return 0., fc[0], nx.from_numpy(phi0, type_as=xk0)
     else:
         if alpha_min is not None or alpha_max is not None:
             alpha = np.clip(alpha, alpha_min, alpha_max)
-        return float(alpha), fc[0], phi1
+        return nx.from_numpy(alpha, type_as=xk0), fc[0], nx.from_numpy(phi1, type_as=xk0)
 
 
 def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_search, G0=None,
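For context, the sketch below shows the kind of GPU usage this patch is meant to unblock: a direct call to the patched `line_search_armijo` on torch tensors, followed by the two solver paths named in the commit title. It is a minimal illustration, not part of the commit; it assumes PyTorch with CUDA available and a POT build containing this fix, and the toy loss `f` and data sizes are made up for the example.

```python
# Minimal sketch (not part of the commit). Assumes PyTorch + CUDA and a POT
# build that includes this fix; the toy loss and data are illustrative only.
import torch
import ot
from ot.optim import line_search_armijo

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Direct call to the patched helper: backend tensors in, backend tensors out.
def f(G):
    return 0.5 * torch.sum(G ** 2)  # toy quadratic loss; its gradient at G is G

xk = torch.ones(5, device=device)   # current iterate, possibly on GPU
gfk = xk.clone()                    # gradient of f at xk
pk = -gfk                           # descent direction
alpha, fc, fa = line_search_armijo(f, xk, pk, gfk, old_fval=f(xk))
# alpha and fa come back via nx.from_numpy(..., type_as=xk0), so they live
# on the same device and backend as the inputs

# The two solver paths named in the commit title, fed GPU inputs:
n = 20
x1 = torch.randn(n, 2, device=device)
x2 = torch.randn(n, 2, device=device)
C1, C2 = ot.dist(x1, x1), ot.dist(x2, x2)   # intra-domain cost matrices
p = torch.ones(n, device=device) / n        # uniform marginals
q = torch.ones(n, device=device) / n

# armijo=True routes the conditional-gradient solver through line_search_armijo
T_cg = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', armijo=True)
# entropic GW, the other code path fixed to run on GPU
T_ent = ot.gromov.entropic_gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=0.05)
```

The design choice visible in the diff: `scalar_search_armijo` runs entirely on NumPy copies of `xk`, `pk`, and `gfk`, the user's loss is evaluated on the original backend arrays inside `phi`, and `alpha`/`phi1` are converted back with `type_as=xk0`, so the line search works regardless of which backend (and device) the solver passes in.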