diff options
Diffstat (limited to 'ot')
-rw-r--r-- | ot/gromov/_bregman.py | 5 | ||||
-rw-r--r-- | ot/optim.py | 38 |
2 files changed, 30 insertions, 13 deletions
diff --git a/ot/gromov/_bregman.py b/ot/gromov/_bregman.py index 5b2f959..b0cccfb 100644 --- a/ot/gromov/_bregman.py +++ b/ot/gromov/_bregman.py @@ -11,9 +11,6 @@ Bregman projections solvers for entropic Gromov-Wasserstein # # License: MIT License -import numpy as np - - from ..bregman import sinkhorn from ..utils import dist, list_to_array, check_random_state from ..backend import get_backend @@ -109,7 +106,7 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, symmetric=None, T = G0 constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, nx) if symmetric is None: - symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10) + symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) if not symmetric: constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, nx) cpt = 0 diff --git a/ot/optim.py b/ot/optim.py index 201f898..58e5596 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -35,6 +35,9 @@ def line_search_armijo( Find an approximate minimum of :math:`f(x_k + \alpha \cdot p_k)` that satisfies the armijo conditions. + .. note:: If the loss function f returns a float (resp. a 1d array) then + the returned alpha and fa are float (resp. 1d arrays). + Parameters ---------- f : callable @@ -45,7 +48,7 @@ def line_search_armijo( descent direction gfk : array-like gradient of `f` at :math:`x_k` - old_fval : float + old_fval : float or 1d array loss value at :math:`x_k` args : tuple, optional arguments given to `f` @@ -61,42 +64,59 @@ def line_search_armijo( If let to its default value None, a backend test will be conducted. Returns ------- - alpha : float + alpha : float or 1d array step that satisfy armijo conditions fc : int nb of function call - fa : float + fa : float or 1d array loss value at step alpha """ if nx is None: xk, pk, gfk = list_to_array(xk, pk, gfk) - nx = get_backend(xk, pk) + xk0, pk0 = xk, pk + nx = get_backend(xk0, pk0) + else: + xk0, pk0 = xk, pk if len(xk.shape) == 0: xk = nx.reshape(xk, (-1,)) + xk = nx.to_numpy(xk) + pk = nx.to_numpy(pk) + gfk = nx.to_numpy(gfk) + fc = [0] def phi(alpha1): + # The callable function operates on nx backend fc[0] += 1 - return f(xk + alpha1 * pk, *args) + alpha10 = nx.from_numpy(alpha1) + fval = f(xk0 + alpha10 * pk0, *args) + if type(fval) is float: + # prevent bug from nx.to_numpy that can look for .cpu or .gpu + return fval + else: + return nx.to_numpy(fval) if old_fval is None: phi0 = phi(0.) - else: + elif type(old_fval) is float: + # prevent bug from nx.to_numpy that can look for .cpu or .gpu phi0 = old_fval + else: + phi0 = nx.to_numpy(old_fval) - derphi0 = nx.sum(pk * gfk) # Quickfix for matrices + derphi0 = np.sum(pk * gfk) # Quickfix for matrices alpha, phi1 = scalar_search_armijo( phi, phi0, derphi0, c1=c1, alpha0=alpha0) if alpha is None: - return 0., fc[0], phi0 + return 0., fc[0], nx.from_numpy(phi0, type_as=xk0) else: if alpha_min is not None or alpha_max is not None: alpha = np.clip(alpha, alpha_min, alpha_max) - return float(alpha), fc[0], phi1 + return nx.from_numpy(alpha, type_as=xk0), fc[0], nx.from_numpy(phi1, type_as=xk0) def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_search, G0=None, |