diff options
Diffstat (limited to 'ot/optim.py')
-rw-r--r-- | ot/optim.py | 496 |
1 files changed, 288 insertions, 208 deletions
diff --git a/ot/optim.py b/ot/optim.py index 5a1d605..58e5596 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -1,11 +1,11 @@ # -*- coding: utf-8 -*- """ -Generic solvers for regularized OT +Generic solvers for regularized OT or its semi-relaxed version. """ # Author: Remi Flamary <remi.flamary@unice.fr> # Titouan Vayer <titouan.vayer@irisa.fr> -# +# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com> # License: MIT License import numpy as np @@ -27,7 +27,7 @@ with warnings.catch_warnings(): def line_search_armijo( f, xk, pk, gfk, old_fval, args=(), c1=1e-4, - alpha0=0.99, alpha_min=None, alpha_max=None + alpha0=0.99, alpha_min=None, alpha_max=None, nx=None, **kwargs ): r""" Armijo linesearch function that works with matrices @@ -35,6 +35,9 @@ def line_search_armijo( Find an approximate minimum of :math:`f(x_k + \alpha \cdot p_k)` that satisfies the armijo conditions. + .. note:: If the loss function f returns a float (resp. a 1d array) then + the returned alpha and fa are float (resp. 1d arrays). + Parameters ---------- f : callable @@ -45,7 +48,7 @@ def line_search_armijo( descent direction gfk : array-like gradient of `f` at :math:`x_k` - old_fval : float + old_fval : float or 1d array loss value at :math:`x_k` args : tuple, optional arguments given to `f` @@ -57,138 +60,97 @@ def line_search_armijo( minimum value for alpha alpha_max : float, optional maximum value for alpha - + nx : backend, optional + If let to its default value None, a backend test will be conducted. Returns ------- - alpha : float + alpha : float or 1d array step that satisfy armijo conditions fc : int nb of function call - fa : float + fa : float or 1d array loss value at step alpha """ - - xk, pk, gfk = list_to_array(xk, pk, gfk) - nx = get_backend(xk, pk) + if nx is None: + xk, pk, gfk = list_to_array(xk, pk, gfk) + xk0, pk0 = xk, pk + nx = get_backend(xk0, pk0) + else: + xk0, pk0 = xk, pk if len(xk.shape) == 0: xk = nx.reshape(xk, (-1,)) + xk = nx.to_numpy(xk) + pk = nx.to_numpy(pk) + gfk = nx.to_numpy(gfk) + fc = [0] def phi(alpha1): + # The callable function operates on nx backend fc[0] += 1 - return f(xk + alpha1 * pk, *args) + alpha10 = nx.from_numpy(alpha1) + fval = f(xk0 + alpha10 * pk0, *args) + if type(fval) is float: + # prevent bug from nx.to_numpy that can look for .cpu or .gpu + return fval + else: + return nx.to_numpy(fval) if old_fval is None: phi0 = phi(0.) - else: + elif type(old_fval) is float: + # prevent bug from nx.to_numpy that can look for .cpu or .gpu phi0 = old_fval + else: + phi0 = nx.to_numpy(old_fval) - derphi0 = nx.sum(pk * gfk) # Quickfix for matrices + derphi0 = np.sum(pk * gfk) # Quickfix for matrices alpha, phi1 = scalar_search_armijo( phi, phi0, derphi0, c1=c1, alpha0=alpha0) if alpha is None: - return 0., fc[0], phi0 + return 0., fc[0], nx.from_numpy(phi0, type_as=xk0) else: if alpha_min is not None or alpha_max is not None: alpha = np.clip(alpha, alpha_min, alpha_max) - return float(alpha), fc[0], phi1 + return nx.from_numpy(alpha, type_as=xk0), fc[0], nx.from_numpy(phi1, type_as=xk0) -def solve_linesearch( - cost, G, deltaG, Mi, f_val, armijo=True, C1=None, C2=None, - reg=None, Gc=None, constC=None, M=None, alpha_min=None, alpha_max=None -): - """ - Solve the linesearch in the FW iterations - - Parameters - ---------- - cost : method - Cost in the FW for the linesearch - G : array-like, shape(ns,nt) - The transport map at a given iteration of the FW - deltaG : array-like (ns,nt) - Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration - Mi : array-like (ns,nt) - Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost - f_val : float - Value of the cost at `G` - armijo : bool, optional - If True the steps of the line-search is found via an armijo research. Else closed form is used. - If there is convergence issues use False. - C1 : array-like (ns,ns), optional - Structure matrix in the source domain. Only used and necessary when armijo=False - C2 : array-like (nt,nt), optional - Structure matrix in the target domain. Only used and necessary when armijo=False - reg : float, optional - Regularization parameter. Only used and necessary when armijo=False - Gc : array-like (ns,nt) - Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False - constC : array-like (ns,nt) - Constant for the gromov cost. See :ref:`[24] <references-solve-linesearch>`. Only used and necessary when armijo=False - M : array-like (ns,nt), optional - Cost matrix between the features. Only used and necessary when armijo=False - alpha_min : float, optional - Minimum value for alpha - alpha_max : float, optional - Maximum value for alpha - - Returns - ------- - alpha : float - The optimal step size of the FW - fc : int - nb of function call. Useless here - f_val : float - The value of the cost for the next iteration +def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_search, G0=None, + numItermax=200, stopThr=1e-9, + stopThr2=1e-9, verbose=False, log=False, **kwargs): + r""" + Solve the general regularized OT problem or its semi-relaxed version with + conditional gradient or generalized conditional gradient depending on the + provided linear program solver. + The function solves the following optimization problem if set as a conditional gradient: - .. _references-solve-linesearch: - References - ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas - "Optimal Transport for structured data with application on graphs" - International Conference on Machine Learning (ICML). 2019. - """ - if armijo: - alpha, fc, f_val = line_search_armijo( - cost, G, deltaG, Mi, f_val, alpha_min=alpha_min, alpha_max=alpha_max - ) - else: # requires symetric matrices - G, deltaG, C1, C2, constC, M = list_to_array(G, deltaG, C1, C2, constC, M) - if isinstance(M, int) or isinstance(M, float): - nx = get_backend(G, deltaG, C1, C2, constC) - else: - nx = get_backend(G, deltaG, C1, C2, constC, M) + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg_1} \cdot f(\gamma) - dot = nx.dot(nx.dot(C1, deltaG), C2) - a = -2 * reg * nx.sum(dot * deltaG) - b = nx.sum((M + reg * constC) * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2) * deltaG)) - c = cost(G) + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - alpha = solve_1d_linesearch_quad(a, b, c) - if alpha_min is not None or alpha_max is not None: - alpha = np.clip(alpha, alpha_min, alpha_max) - fc = None - f_val = cost(G + alpha * deltaG) + \gamma^T \mathbf{1} &= \mathbf{b} (optional constraint) - return alpha, fc, f_val + \gamma &\geq 0 + where : + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`f` is the regularization term (and `df` is its gradient) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) -def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, - stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs): - r""" - Solve the general regularized OT problem with conditional gradient + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>` - The function solves the following optimization problem: + The function solves the following optimization problem if set a generalized conditional gradient: .. math:: \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - \mathrm{reg} \cdot f(\gamma) + \mathrm{reg_1}\cdot f(\gamma) + \mathrm{reg_2}\cdot\Omega(\gamma) s.t. \ \gamma \mathbf{1} &= \mathbf{a} @@ -197,29 +159,39 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, \gamma &\geq 0 where : - - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix - - :math:`f` is the regularization term (and `df` is its gradient) - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) - - The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>` + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + The algorithm used for solving the problem is the generalized conditional gradient as discussed in :ref:`[5, 7] <references-gcg>` Parameters ---------- a : array-like, shape (ns,) samples weights in the source domain b : array-like, shape (nt,) - samples in the target domain + samples weights in the target domain M : array-like, shape (ns, nt) loss matrix - reg : float + f : function + Regularization function taking a transportation matrix as argument + df: function + Gradient of the regularization function taking a transportation matrix as argument + reg1 : float Regularization term >0 + reg2 : float, + Entropic Regularization term >0. Ignored if set to None. + lp_solver: function, + linear program solver for direction finding of the (generalized) conditional gradient. + If set to emd will solve the general regularized OT problem using cg. + If set to lp_semi_relaxed_OT will solve the general regularized semi-relaxed OT problem using cg. + If set to sinkhorn will solve the general regularized OT problem using generalized cg. + line_search: function, + Function to find the optimal step. Currently used instances are: + line_search_armijo (generic solver). solve_gromov_linesearch for (F)GW problem. + solve_semirelaxed_gromov_linesearch for sr(F)GW problem. gcg_linesearch for the Generalized cg. G0 : array-like, shape (ns,nt), optional initial guess (default is indep joint density) numItermax : int, optional Max number of iterations - numItermaxEmd : int, optional - Max number of iterations for emd stopThr : float, optional Stop threshold on the relative variation (>0) stopThr2 : float, optional @@ -240,16 +212,20 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, .. _references-cg: + .. _references_gcg: References ---------- .. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. + .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + + .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567. + See Also -------- ot.lp.emd : Unregularized optimal ransport ot.bregman.sinkhorn : Entropic regularized optimal transport - """ a, b, M, G0 = list_to_array(a, b, M, G0) if isinstance(M, int) or isinstance(M, float): @@ -265,42 +241,45 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, if G0 is None: G = nx.outer(a, b) else: - G = G0 - - def cost(G): - return nx.sum(M * G) + reg * f(G) + # to not change G0 in place. + G = nx.copy(G0) - f_val = cost(G) + if reg2 is None: + def cost(G): + return nx.sum(M * G) + reg1 * f(G) + else: + def cost(G): + return nx.sum(M * G) + reg1 * f(G) + reg2 * nx.sum(G * nx.log(G)) + cost_G = cost(G) if log: - log['loss'].append(f_val) + log['loss'].append(cost_G) it = 0 if verbose: print('{:5s}|{:12s}|{:8s}|{:8s}'.format( 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) - print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0)) + print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, cost_G, 0, 0)) while loop: it += 1 - old_fval = f_val - + old_cost_G = cost_G # problem linearization - Mi = M + reg * df(G) + Mi = M + reg1 * df(G) + + if not (reg2 is None): + Mi = Mi + reg2 * (1 + nx.log(G)) # set M positive - Mi += nx.min(Mi) + Mi = Mi + nx.min(Mi) # solve linear program - Gc, logemd = emd(a, b, Mi, numItermax=numItermaxEmd, log=True) + Gc, innerlog_ = lp_solver(a, b, Mi, **kwargs) + # line search deltaG = Gc - G - # line search - alpha, fc, f_val = solve_linesearch( - cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, - alpha_min=0., alpha_max=1., **kwargs - ) + alpha, fc, cost_G = line_search(cost, G, deltaG, Mi, cost_G, **kwargs) G = G + alpha * deltaG @@ -308,29 +287,197 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, if it >= numItermax: loop = 0 - abs_delta_fval = abs(f_val - old_fval) - relative_delta_fval = abs_delta_fval / abs(f_val) - if relative_delta_fval < stopThr or abs_delta_fval < stopThr2: + abs_delta_cost_G = abs(cost_G - old_cost_G) + relative_delta_cost_G = abs_delta_cost_G / abs(cost_G) + if relative_delta_cost_G < stopThr or abs_delta_cost_G < stopThr2: loop = 0 if log: - log['loss'].append(f_val) + log['loss'].append(cost_G) if verbose: if it % 20 == 0: print('{:5s}|{:12s}|{:8s}|{:8s}'.format( 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) - print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) + print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, cost_G, relative_delta_cost_G, abs_delta_cost_G)) if log: - log.update(logemd) + log.update(innerlog_) return G, log else: return G +def cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo, + numItermax=200, numItermaxEmd=100000, stopThr=1e-9, stopThr2=1e-9, + verbose=False, log=False, **kwargs): + r""" + Solve the general regularized OT problem with conditional gradient + + The function solves the following optimization problem: + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot f(\gamma) + + s.t. \ \gamma \mathbf{1} &= \mathbf{a} + + \gamma^T \mathbf{1} &= \mathbf{b} + + \gamma &\geq 0 + where : + + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`f` is the regularization term (and `df` is its gradient) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) + + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>` + + + Parameters + ---------- + a : array-like, shape (ns,) + samples weights in the source domain + b : array-like, shape (nt,) + samples in the target domain + M : array-like, shape (ns, nt) + loss matrix + reg : float + Regularization term >0 + G0 : array-like, shape (ns,nt), optional + initial guess (default is indep joint density) + line_search: function, + Function to find the optimal step. + Default is line_search_armijo. + numItermax : int, optional + Max number of iterations + numItermaxEmd : int, optional + Max number of iterations for emd + stopThr : float, optional + Stop threshold on the relative variation (>0) + stopThr2 : float, optional + Stop threshold on the absolute variation (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + **kwargs : dict + Parameters for linesearch + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + + .. _references-cg: + References + ---------- + + .. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. + + See Also + -------- + ot.lp.emd : Unregularized optimal ransport + ot.bregman.sinkhorn : Entropic regularized optimal transport + + """ + + def lp_solver(a, b, M, **kwargs): + return emd(a, b, M, numItermaxEmd, log=True) + + return generic_conditional_gradient(a, b, M, f, df, reg, None, lp_solver, line_search, G0=G0, + numItermax=numItermax, stopThr=stopThr, + stopThr2=stopThr2, verbose=verbose, log=log, **kwargs) + + +def semirelaxed_cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo, + numItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs): + r""" + Solve the general regularized and semi-relaxed OT problem with conditional gradient + + The function solves the following optimization problem: + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot f(\gamma) + + s.t. \ \gamma \mathbf{1} &= \mathbf{a} + + \gamma &\geq 0 + where : + + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`f` is the regularization term (and `df` is its gradient) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) + + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>` + + + Parameters + ---------- + a : array-like, shape (ns,) + samples weights in the source domain + b : array-like, shape (nt,) + currently estimated samples weights in the target domain + M : array-like, shape (ns, nt) + loss matrix + reg : float + Regularization term >0 + G0 : array-like, shape (ns,nt), optional + initial guess (default is indep joint density) + line_search: function, + Function to find the optimal step. + Default is the armijo line-search. + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on the relative variation (>0) + stopThr2 : float, optional + Stop threshold on the absolute variation (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + **kwargs : dict + Parameters for linesearch + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + + .. _references-cg: + References + ---------- + + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2021. + + """ + + nx = get_backend(a, b) + + def lp_solver(a, b, Mi, **kwargs): + # get minimum by rows as binary mask + Gc = nx.ones(1, type_as=a) * (Mi == nx.reshape(nx.min(Mi, axis=1), (-1, 1))) + Gc *= nx.reshape((a / nx.sum(Gc, axis=1)), (-1, 1)) + # return by default an empty inner_log + return Gc, {} + + return generic_conditional_gradient(a, b, M, f, df, reg, None, lp_solver, line_search, G0=G0, + numItermax=numItermax, stopThr=stopThr, + stopThr2=stopThr2, verbose=verbose, log=log, **kwargs) + + def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, - numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False): + numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs): r""" Solve the general regularized OT problem with the generalized conditional gradient @@ -403,81 +550,18 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, ot.optim.cg : conditional gradient """ - a, b, M, G0 = list_to_array(a, b, M, G0) - nx = get_backend(a, b, M) - - loop = 1 - - if log: - log = {'loss': []} - - if G0 is None: - G = nx.outer(a, b) - else: - G = G0 - - def cost(G): - return nx.sum(M * G) + reg1 * nx.sum(G * nx.log(G)) + reg2 * f(G) - - f_val = cost(G) - if log: - log['loss'].append(f_val) - - it = 0 - - if verbose: - print('{:5s}|{:12s}|{:8s}|{:8s}'.format( - 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) - print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0)) - - while loop: - - it += 1 - old_fval = f_val - - # problem linearization - Mi = M + reg2 * df(G) - - # solve linear program with Sinkhorn - # Gc = sinkhorn_stabilized(a,b, Mi, reg1, numItermax = numInnerItermax) - Gc = sinkhorn(a, b, Mi, reg1, numItermax=numInnerItermax) - deltaG = Gc - G - - # line search - dcost = Mi + reg1 * (1 + nx.log(G)) # ?? - alpha, fc, f_val = line_search_armijo( - cost, G, deltaG, dcost, f_val, alpha_min=0., alpha_max=1. - ) + def lp_solver(a, b, Mi, **kwargs): + return sinkhorn(a, b, Mi, reg1, numItermax=numInnerItermax, log=True, **kwargs) - G = G + alpha * deltaG + def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): + return line_search_armijo(cost, G, deltaG, Mi, cost_G, **kwargs) - # test convergence - if it >= numItermax: - loop = 0 + return generic_conditional_gradient(a, b, M, f, df, reg2, reg1, lp_solver, line_search, G0=G0, + numItermax=numItermax, stopThr=stopThr, stopThr2=stopThr2, verbose=verbose, log=log, **kwargs) - abs_delta_fval = abs(f_val - old_fval) - relative_delta_fval = abs_delta_fval / abs(f_val) - if relative_delta_fval < stopThr or abs_delta_fval < stopThr2: - loop = 0 - - if log: - log['loss'].append(f_val) - - if verbose: - if it % 20 == 0: - print('{:5s}|{:12s}|{:8s}|{:8s}'.format( - 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) - print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) - - if log: - return G, log - else: - return G - - -def solve_1d_linesearch_quad(a, b, c): +def solve_1d_linesearch_quad(a, b): r""" For any convex or non-convex 1d quadratic function `f`, solve the following problem: @@ -487,7 +571,7 @@ def solve_1d_linesearch_quad(a, b, c): Parameters ---------- - a,b,c : float + a,b : float or tensors (1,) The coefficients of the quadratic function Returns @@ -495,15 +579,11 @@ def solve_1d_linesearch_quad(a, b, c): x : float The optimal value which leads to the minimal cost """ - f0 = c - df0 = b - f1 = a + f0 + df0 - if a > 0: # convex - minimum = min(1, max(0, np.divide(-b, 2.0 * a))) + minimum = min(1., max(0., -b / (2.0 * a))) return minimum else: # non convex - if f0 > f1: - return 1 + if a + b < 0: + return 1. else: - return 0 + return 0. |