From a5930d3b3a446bf860d6dfacc1e17151fae1dd1d Mon Sep 17 00:00:00 2001 From: Cédric Vincent-Cuaz Date: Thu, 9 Mar 2023 14:21:33 +0100 Subject: [MRG] Semi-relaxed (fused) gromov-wasserstein divergence and improvements of gromov-wasserstein solvers (#431) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * maj gw/ srgw/ generic cg solver * correct pep8 on current state * fix bug previous tests * fix pep8 * fix bug srGW constC in loss and gradient * fix doc html * fix doc html * start updating test_optim.py * update tests gromov and optim - plus fix gromov dependencies * add symmetry feature to entropic gw * add symmetry feature to entropic gw * add exemple for sr(F)GW matchings * small stuff * remove (reg,M) from line-search/ complete srgw tests with backend * remove backend repetitions / rename fG to costG/ fix innerlog to True * fix pep8 * take comments into account / new nx parameters still to test * factor (f)gw2 + test new backend parameters in ot.gromov + harmonize stopping criterions * split gromov.py in ot/gromov/ + update test_gromov with helper_backend functions * manual documentaion gromov * remove circular autosummary * trying stuff * debug documentation * alphabetic ordering of module * merge into branch * add note in entropic gw solvers --------- Co-authored-by: Rémi Flamary --- ot/optim.py | 460 ++++++++++++++++++++++++++++++++++-------------------------- 1 file changed, 260 insertions(+), 200 deletions(-) (limited to 'ot/optim.py') diff --git a/ot/optim.py b/ot/optim.py index 5a1d605..201f898 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -1,11 +1,11 @@ # -*- coding: utf-8 -*- """ -Generic solvers for regularized OT +Generic solvers for regularized OT or its semi-relaxed version. """ # Author: Remi Flamary # Titouan Vayer -# +# Cédric Vincent-Cuaz # License: MIT License import numpy as np @@ -27,7 +27,7 @@ with warnings.catch_warnings(): def line_search_armijo( f, xk, pk, gfk, old_fval, args=(), c1=1e-4, - alpha0=0.99, alpha_min=None, alpha_max=None + alpha0=0.99, alpha_min=None, alpha_max=None, nx=None, **kwargs ): r""" Armijo linesearch function that works with matrices @@ -57,7 +57,8 @@ def line_search_armijo( minimum value for alpha alpha_max : float, optional maximum value for alpha - + nx : backend, optional + If let to its default value None, a backend test will be conducted. Returns ------- alpha : float @@ -68,9 +69,9 @@ def line_search_armijo( loss value at step alpha """ - - xk, pk, gfk = list_to_array(xk, pk, gfk) - nx = get_backend(xk, pk) + if nx is None: + xk, pk, gfk = list_to_array(xk, pk, gfk) + nx = get_backend(xk, pk) if len(xk.shape) == 0: xk = nx.reshape(xk, (-1,)) @@ -98,97 +99,38 @@ def line_search_armijo( return float(alpha), fc[0], phi1 -def solve_linesearch( - cost, G, deltaG, Mi, f_val, armijo=True, C1=None, C2=None, - reg=None, Gc=None, constC=None, M=None, alpha_min=None, alpha_max=None -): - """ - Solve the linesearch in the FW iterations - - Parameters - ---------- - cost : method - Cost in the FW for the linesearch - G : array-like, shape(ns,nt) - The transport map at a given iteration of the FW - deltaG : array-like (ns,nt) - Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration - Mi : array-like (ns,nt) - Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost - f_val : float - Value of the cost at `G` - armijo : bool, optional - If True the steps of the line-search is found via an armijo research. 
Else closed form is used. - If there is convergence issues use False. - C1 : array-like (ns,ns), optional - Structure matrix in the source domain. Only used and necessary when armijo=False - C2 : array-like (nt,nt), optional - Structure matrix in the target domain. Only used and necessary when armijo=False - reg : float, optional - Regularization parameter. Only used and necessary when armijo=False - Gc : array-like (ns,nt) - Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False - constC : array-like (ns,nt) - Constant for the gromov cost. See :ref:`[24] `. Only used and necessary when armijo=False - M : array-like (ns,nt), optional - Cost matrix between the features. Only used and necessary when armijo=False - alpha_min : float, optional - Minimum value for alpha - alpha_max : float, optional - Maximum value for alpha - - Returns - ------- - alpha : float - The optimal step size of the FW - fc : int - nb of function call. Useless here - f_val : float - The value of the cost for the next iteration +def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_search, G0=None, + numItermax=200, stopThr=1e-9, + stopThr2=1e-9, verbose=False, log=False, **kwargs): + r""" + Solve the general regularized OT problem or its semi-relaxed version with + conditional gradient or generalized conditional gradient depending on the + provided linear program solver. + The function solves the following optimization problem if set as a conditional gradient: - .. _references-solve-linesearch: - References - ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas - "Optimal Transport for structured data with application on graphs" - International Conference on Machine Learning (ICML). 2019. - """ - if armijo: - alpha, fc, f_val = line_search_armijo( - cost, G, deltaG, Mi, f_val, alpha_min=alpha_min, alpha_max=alpha_max - ) - else: # requires symetric matrices - G, deltaG, C1, C2, constC, M = list_to_array(G, deltaG, C1, C2, constC, M) - if isinstance(M, int) or isinstance(M, float): - nx = get_backend(G, deltaG, C1, C2, constC) - else: - nx = get_backend(G, deltaG, C1, C2, constC, M) + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg_1} \cdot f(\gamma) - dot = nx.dot(nx.dot(C1, deltaG), C2) - a = -2 * reg * nx.sum(dot * deltaG) - b = nx.sum((M + reg * constC) * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2) * deltaG)) - c = cost(G) + s.t. 
\ \gamma \mathbf{1} &= \mathbf{a} - alpha = solve_1d_linesearch_quad(a, b, c) - if alpha_min is not None or alpha_max is not None: - alpha = np.clip(alpha, alpha_min, alpha_max) - fc = None - f_val = cost(G + alpha * deltaG) + \gamma^T \mathbf{1} &= \mathbf{b} (optional constraint) - return alpha, fc, f_val + \gamma &\geq 0 + where : + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`f` is the regularization term (and `df` is its gradient) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) -def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, - stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs): - r""" - Solve the general regularized OT problem with conditional gradient + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] ` - The function solves the following optimization problem: + The function solves the following optimization problem if set a generalized conditional gradient: .. math:: \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - \mathrm{reg} \cdot f(\gamma) + \mathrm{reg_1}\cdot f(\gamma) + \mathrm{reg_2}\cdot\Omega(\gamma) s.t. \ \gamma \mathbf{1} &= \mathbf{a} @@ -197,29 +139,39 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, \gamma &\geq 0 where : - - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix - - :math:`f` is the regularization term (and `df` is its gradient) - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) - - The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] ` + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + The algorithm used for solving the problem is the generalized conditional gradient as discussed in :ref:`[5, 7] ` Parameters ---------- a : array-like, shape (ns,) samples weights in the source domain b : array-like, shape (nt,) - samples in the target domain + samples weights in the target domain M : array-like, shape (ns, nt) loss matrix - reg : float + f : function + Regularization function taking a transportation matrix as argument + df: function + Gradient of the regularization function taking a transportation matrix as argument + reg1 : float Regularization term >0 + reg2 : float, + Entropic Regularization term >0. Ignored if set to None. + lp_solver: function, + linear program solver for direction finding of the (generalized) conditional gradient. + If set to emd will solve the general regularized OT problem using cg. + If set to lp_semi_relaxed_OT will solve the general regularized semi-relaxed OT problem using cg. + If set to sinkhorn will solve the general regularized OT problem using generalized cg. + line_search: function, + Function to find the optimal step. Currently used instances are: + line_search_armijo (generic solver). solve_gromov_linesearch for (F)GW problem. + solve_semirelaxed_gromov_linesearch for sr(F)GW problem. gcg_linesearch for the Generalized cg. G0 : array-like, shape (ns,nt), optional initial guess (default is indep joint density) numItermax : int, optional Max number of iterations - numItermaxEmd : int, optional - Max number of iterations for emd stopThr : float, optional Stop threshold on the relative variation (>0) stopThr2 : float, optional @@ -240,16 +192,20 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, .. _references-cg: + .. 
_references_gcg: References ---------- .. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. + .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + + .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567. + See Also -------- ot.lp.emd : Unregularized optimal ransport ot.bregman.sinkhorn : Entropic regularized optimal transport - """ a, b, M, G0 = list_to_array(a, b, M, G0) if isinstance(M, int) or isinstance(M, float): @@ -265,42 +221,45 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, if G0 is None: G = nx.outer(a, b) else: - G = G0 - - def cost(G): - return nx.sum(M * G) + reg * f(G) + # to not change G0 in place. + G = nx.copy(G0) - f_val = cost(G) + if reg2 is None: + def cost(G): + return nx.sum(M * G) + reg1 * f(G) + else: + def cost(G): + return nx.sum(M * G) + reg1 * f(G) + reg2 * nx.sum(G * nx.log(G)) + cost_G = cost(G) if log: - log['loss'].append(f_val) + log['loss'].append(cost_G) it = 0 if verbose: print('{:5s}|{:12s}|{:8s}|{:8s}'.format( 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) - print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0)) + print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, cost_G, 0, 0)) while loop: it += 1 - old_fval = f_val - + old_cost_G = cost_G # problem linearization - Mi = M + reg * df(G) + Mi = M + reg1 * df(G) + + if not (reg2 is None): + Mi = Mi + reg2 * (1 + nx.log(G)) # set M positive - Mi += nx.min(Mi) + Mi = Mi + nx.min(Mi) # solve linear program - Gc, logemd = emd(a, b, Mi, numItermax=numItermaxEmd, log=True) + Gc, innerlog_ = lp_solver(a, b, Mi, **kwargs) + # line search deltaG = Gc - G - # line search - alpha, fc, f_val = solve_linesearch( - cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, - alpha_min=0., alpha_max=1., **kwargs - ) + alpha, fc, cost_G = line_search(cost, G, deltaG, Mi, cost_G, **kwargs) G = G + alpha * deltaG @@ -308,29 +267,197 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, if it >= numItermax: loop = 0 - abs_delta_fval = abs(f_val - old_fval) - relative_delta_fval = abs_delta_fval / abs(f_val) - if relative_delta_fval < stopThr or abs_delta_fval < stopThr2: + abs_delta_cost_G = abs(cost_G - old_cost_G) + relative_delta_cost_G = abs_delta_cost_G / abs(cost_G) + if relative_delta_cost_G < stopThr or abs_delta_cost_G < stopThr2: loop = 0 if log: - log['loss'].append(f_val) + log['loss'].append(cost_G) if verbose: if it % 20 == 0: print('{:5s}|{:12s}|{:8s}|{:8s}'.format( 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) - print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) + print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, cost_G, relative_delta_cost_G, abs_delta_cost_G)) if log: - log.update(logemd) + log.update(innerlog_) return G, log else: return G +def cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo, + numItermax=200, numItermaxEmd=100000, stopThr=1e-9, stopThr2=1e-9, + verbose=False, log=False, **kwargs): + r""" + Solve the general regularized OT problem with conditional gradient + + The function solves the following optimization problem: + + .. 
math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot f(\gamma) + + s.t. \ \gamma \mathbf{1} &= \mathbf{a} + + \gamma^T \mathbf{1} &= \mathbf{b} + + \gamma &\geq 0 + where : + + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`f` is the regularization term (and `df` is its gradient) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) + + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] ` + + + Parameters + ---------- + a : array-like, shape (ns,) + samples weights in the source domain + b : array-like, shape (nt,) + samples in the target domain + M : array-like, shape (ns, nt) + loss matrix + reg : float + Regularization term >0 + G0 : array-like, shape (ns,nt), optional + initial guess (default is indep joint density) + line_search: function, + Function to find the optimal step. + Default is line_search_armijo. + numItermax : int, optional + Max number of iterations + numItermaxEmd : int, optional + Max number of iterations for emd + stopThr : float, optional + Stop threshold on the relative variation (>0) + stopThr2 : float, optional + Stop threshold on the absolute variation (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + **kwargs : dict + Parameters for linesearch + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + + .. _references-cg: + References + ---------- + + .. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. + + See Also + -------- + ot.lp.emd : Unregularized optimal ransport + ot.bregman.sinkhorn : Entropic regularized optimal transport + + """ + + def lp_solver(a, b, M, **kwargs): + return emd(a, b, M, numItermaxEmd, log=True) + + return generic_conditional_gradient(a, b, M, f, df, reg, None, lp_solver, line_search, G0=G0, + numItermax=numItermax, stopThr=stopThr, + stopThr2=stopThr2, verbose=verbose, log=log, **kwargs) + + +def semirelaxed_cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo, + numItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs): + r""" + Solve the general regularized and semi-relaxed OT problem with conditional gradient + + The function solves the following optimization problem: + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot f(\gamma) + + s.t. \ \gamma \mathbf{1} &= \mathbf{a} + + \gamma &\geq 0 + where : + + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`f` is the regularization term (and `df` is its gradient) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) + + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] ` + + + Parameters + ---------- + a : array-like, shape (ns,) + samples weights in the source domain + b : array-like, shape (nt,) + currently estimated samples weights in the target domain + M : array-like, shape (ns, nt) + loss matrix + reg : float + Regularization term >0 + G0 : array-like, shape (ns,nt), optional + initial guess (default is indep joint density) + line_search: function, + Function to find the optimal step. + Default is the armijo line-search. 
+ numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on the relative variation (>0) + stopThr2 : float, optional + Stop threshold on the absolute variation (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + **kwargs : dict + Parameters for linesearch + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + + .. _references-cg: + References + ---------- + + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2021. + + """ + + nx = get_backend(a, b) + + def lp_solver(a, b, Mi, **kwargs): + # get minimum by rows as binary mask + Gc = nx.ones(1, type_as=a) * (Mi == nx.reshape(nx.min(Mi, axis=1), (-1, 1))) + Gc *= nx.reshape((a / nx.sum(Gc, axis=1)), (-1, 1)) + # return by default an empty inner_log + return Gc, {} + + return generic_conditional_gradient(a, b, M, f, df, reg, None, lp_solver, line_search, G0=G0, + numItermax=numItermax, stopThr=stopThr, + stopThr2=stopThr2, verbose=verbose, log=log, **kwargs) + + def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, - numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False): + numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs): r""" Solve the general regularized OT problem with the generalized conditional gradient @@ -403,81 +530,18 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, ot.optim.cg : conditional gradient """ - a, b, M, G0 = list_to_array(a, b, M, G0) - nx = get_backend(a, b, M) - - loop = 1 - - if log: - log = {'loss': []} - - if G0 is None: - G = nx.outer(a, b) - else: - G = G0 - - def cost(G): - return nx.sum(M * G) + reg1 * nx.sum(G * nx.log(G)) + reg2 * f(G) - - f_val = cost(G) - if log: - log['loss'].append(f_val) - - it = 0 - - if verbose: - print('{:5s}|{:12s}|{:8s}|{:8s}'.format( - 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) - print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0)) - while loop: - - it += 1 - old_fval = f_val - - # problem linearization - Mi = M + reg2 * df(G) - - # solve linear program with Sinkhorn - # Gc = sinkhorn_stabilized(a,b, Mi, reg1, numItermax = numInnerItermax) - Gc = sinkhorn(a, b, Mi, reg1, numItermax=numInnerItermax) - - deltaG = Gc - G - - # line search - dcost = Mi + reg1 * (1 + nx.log(G)) # ?? - alpha, fc, f_val = line_search_armijo( - cost, G, deltaG, dcost, f_val, alpha_min=0., alpha_max=1. 
- ) - - G = G + alpha * deltaG - - # test convergence - if it >= numItermax: - loop = 0 - - abs_delta_fval = abs(f_val - old_fval) - relative_delta_fval = abs_delta_fval / abs(f_val) + def lp_solver(a, b, Mi, **kwargs): + return sinkhorn(a, b, Mi, reg1, numItermax=numInnerItermax, log=True, **kwargs) - if relative_delta_fval < stopThr or abs_delta_fval < stopThr2: - loop = 0 + def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): + return line_search_armijo(cost, G, deltaG, Mi, cost_G, **kwargs) - if log: - log['loss'].append(f_val) + return generic_conditional_gradient(a, b, M, f, df, reg2, reg1, lp_solver, line_search, G0=G0, + numItermax=numItermax, stopThr=stopThr, stopThr2=stopThr2, verbose=verbose, log=log, **kwargs) - if verbose: - if it % 20 == 0: - print('{:5s}|{:12s}|{:8s}|{:8s}'.format( - 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) - print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) - if log: - return G, log - else: - return G - - -def solve_1d_linesearch_quad(a, b, c): +def solve_1d_linesearch_quad(a, b): r""" For any convex or non-convex 1d quadratic function `f`, solve the following problem: @@ -487,7 +551,7 @@ def solve_1d_linesearch_quad(a, b, c): Parameters ---------- - a,b,c : float + a,b : float or tensors (1,) The coefficients of the quadratic function Returns @@ -495,15 +559,11 @@ def solve_1d_linesearch_quad(a, b, c): x : float The optimal value which leads to the minimal cost """ - f0 = c - df0 = b - f1 = a + f0 + df0 - if a > 0: # convex - minimum = min(1, max(0, np.divide(-b, 2.0 * a))) + minimum = min(1., max(0., -b / (2.0 * a))) return minimum else: # non convex - if f0 > f1: - return 1 + if a + b < 0: + return 1. else: - return 0 + return 0. -- cgit v1.2.3 From 583501652517c4f1dbd8572e9f942551a9e54a1f Mon Sep 17 00:00:00 2001 From: Cédric Vincent-Cuaz Date: Thu, 16 Mar 2023 08:05:54 +0100 Subject: [MRG] fix bugs of gw_entropic and armijo to run on gpu (#446) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * maj gw/ srgw/ generic cg solver * correct pep8 on current state * fix bug previous tests * fix pep8 * fix bug srGW constC in loss and gradient * fix doc html * fix doc html * start updating test_optim.py * update tests gromov and optim - plus fix gromov dependencies * add symmetry feature to entropic gw * add symmetry feature to entropic gw * add exemple for sr(F)GW matchings * small stuff * remove (reg,M) from line-search/ complete srgw tests with backend * remove backend repetitions / rename fG to costG/ fix innerlog to True * fix pep8 * take comments into account / new nx parameters still to test * factor (f)gw2 + test new backend parameters in ot.gromov + harmonize stopping criterions * split gromov.py in ot/gromov/ + update test_gromov with helper_backend functions * manual documentaion gromov * remove circular autosummary * trying stuff * debug documentation * alphabetic ordering of module * merge into branch * add note in entropic gw solvers * fix exemples/gromov doc * add fixed issue to releases.md * fix bugs of gw_entropic and armijo to run on gpu * add pr to releases.md * fix pep8 * fix call to backend in line_search_armijo * correct docstring generic_conditional_gradient --------- Co-authored-by: Rémi Flamary --- RELEASES.md | 3 ++- ot/gromov/_bregman.py | 5 +---- ot/optim.py | 38 ++++++++++++++++++++++++++--------- test/test_gromov.py | 16 +++++++++++++++ test/test_optim.py | 55 ++++++++++++++++++++++++++++++++++++++++++++++++--- 5 files 
changed, 100 insertions(+), 17 deletions(-) (limited to 'ot/optim.py') diff --git a/RELEASES.md b/RELEASES.md index da4d7bb..b6e12d9 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -47,6 +47,7 @@ PR #413) that explicitly specified `stopThr=1e-9` (Issue #421, PR #422). - Fixed a bug breaking an example where we would try to make an array of arrays of different shapes (Issue #424, PR #425) - Fixed an issue with the documentation gallery section (PR #444) +- Fixed issues with cuda variables for `line_search_armijo` and `entropic_gromov_wasserstein` (Issue #445, #PR 446) ## 0.8.2 @@ -571,4 +572,4 @@ It provides the following solvers: * Optimal transport for domain adaptation with group lasso regularization * Conditional gradient and Generalized conditional gradient for regularized OT. -Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder. +Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder. \ No newline at end of file diff --git a/ot/gromov/_bregman.py b/ot/gromov/_bregman.py index 5b2f959..b0cccfb 100644 --- a/ot/gromov/_bregman.py +++ b/ot/gromov/_bregman.py @@ -11,9 +11,6 @@ Bregman projections solvers for entropic Gromov-Wasserstein # # License: MIT License -import numpy as np - - from ..bregman import sinkhorn from ..utils import dist, list_to_array, check_random_state from ..backend import get_backend @@ -109,7 +106,7 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, symmetric=None, T = G0 constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, nx) if symmetric is None: - symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10) + symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) if not symmetric: constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, nx) cpt = 0 diff --git a/ot/optim.py b/ot/optim.py index 201f898..58e5596 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -35,6 +35,9 @@ def line_search_armijo( Find an approximate minimum of :math:`f(x_k + \alpha \cdot p_k)` that satisfies the armijo conditions. + .. note:: If the loss function f returns a float (resp. a 1d array) then + the returned alpha and fa are float (resp. 1d arrays). + Parameters ---------- f : callable @@ -45,7 +48,7 @@ def line_search_armijo( descent direction gfk : array-like gradient of `f` at :math:`x_k` - old_fval : float + old_fval : float or 1d array loss value at :math:`x_k` args : tuple, optional arguments given to `f` @@ -61,42 +64,59 @@ def line_search_armijo( If let to its default value None, a backend test will be conducted. Returns ------- - alpha : float + alpha : float or 1d array step that satisfy armijo conditions fc : int nb of function call - fa : float + fa : float or 1d array loss value at step alpha """ if nx is None: xk, pk, gfk = list_to_array(xk, pk, gfk) - nx = get_backend(xk, pk) + xk0, pk0 = xk, pk + nx = get_backend(xk0, pk0) + else: + xk0, pk0 = xk, pk if len(xk.shape) == 0: xk = nx.reshape(xk, (-1,)) + xk = nx.to_numpy(xk) + pk = nx.to_numpy(pk) + gfk = nx.to_numpy(gfk) + fc = [0] def phi(alpha1): + # The callable function operates on nx backend fc[0] += 1 - return f(xk + alpha1 * pk, *args) + alpha10 = nx.from_numpy(alpha1) + fval = f(xk0 + alpha10 * pk0, *args) + if type(fval) is float: + # prevent bug from nx.to_numpy that can look for .cpu or .gpu + return fval + else: + return nx.to_numpy(fval) if old_fval is None: phi0 = phi(0.) 
- else: + elif type(old_fval) is float: + # prevent bug from nx.to_numpy that can look for .cpu or .gpu phi0 = old_fval + else: + phi0 = nx.to_numpy(old_fval) - derphi0 = nx.sum(pk * gfk) # Quickfix for matrices + derphi0 = np.sum(pk * gfk) # Quickfix for matrices alpha, phi1 = scalar_search_armijo( phi, phi0, derphi0, c1=c1, alpha0=alpha0) if alpha is None: - return 0., fc[0], phi0 + return 0., fc[0], nx.from_numpy(phi0, type_as=xk0) else: if alpha_min is not None or alpha_max is not None: alpha = np.clip(alpha, alpha_min, alpha_max) - return float(alpha), fc[0], phi1 + return nx.from_numpy(alpha, type_as=xk0), fc[0], nx.from_numpy(phi1, type_as=xk0) def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_search, G0=None, diff --git a/test/test_gromov.py b/test/test_gromov.py index cfccce7..80b6df4 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -214,6 +214,7 @@ def test_gromov2_gradients(): C11 = torch.tensor(C1, requires_grad=True, device=device) C12 = torch.tensor(C2, requires_grad=True, device=device) + # Test with exact line-search val = ot.gromov_wasserstein2(C11, C12, p1, q1) val.backward() @@ -224,6 +225,21 @@ def test_gromov2_gradients(): assert C11.shape == C11.grad.shape assert C12.shape == C12.grad.shape + # Test with armijo line-search + q1.grad = None + p1.grad = None + C11.grad = None + C12.grad = None + val = ot.gromov_wasserstein2(C11, C12, p1, q1, armijo=True) + + val.backward() + + assert val.device == p1.device + assert q1.shape == q1.grad.shape + assert p1.shape == p1.grad.shape + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + def test_gw_helper_backend(nx): n_samples = 20 # nb samples diff --git a/test/test_optim.py b/test/test_optim.py index 129fe22..a43e704 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -135,16 +135,18 @@ def test_line_search_armijo(nx): xk = np.array([[0.25, 0.25], [0.25, 0.25]]) pk = np.array([[-0.25, 0.25], [0.25, -0.25]]) gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]]) - old_fval = -123 + old_fval = -123. xkb, pkb, gfkb = nx.from_numpy(xk, pk, gfk) + def f(x): + return 1. # Should not throw an exception and return 0. 
for alpha alpha, a, b = ot.optim.line_search_armijo( - lambda x: 1, xkb, pkb, gfkb, old_fval + f, xkb, pkb, gfkb, old_fval ) alpha_np, anp, bnp = ot.optim.line_search_armijo( - lambda x: 1, xk, pk, gfk, old_fval + f, xk, pk, gfk, old_fval ) assert a == anp assert b == bnp @@ -182,3 +184,50 @@ def test_line_search_armijo(nx): old_fval = f(xk) alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval) np.testing.assert_allclose(alpha, 0.1) + + +def test_line_search_armijo_dtype_device(nx): + for tp in nx.__type_list__: + def f(x): + return nx.sum((x - 5.0) ** 2) + + def grad(x): + return 2 * (x - 5.0) + + xk = np.array([[[-5.0, -5.0]]]) + pk = np.array([[[100.0, 100.0]]]) + xkb, pkb = nx.from_numpy(xk, pk, type_as=tp) + gfkb = grad(xkb) + old_fval = f(xkb) + + # chech the case where the optimum is on the direction + alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval) + alpha = nx.to_numpy(alpha) + np.testing.assert_allclose(alpha, 0.1) + nx.assert_same_dtype_device(old_fval, fval) + + # check the case where the direction is not far enough + pk = np.array([[[3.0, 3.0]]]) + pkb = nx.from_numpy(pk, type_as=tp) + alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval, alpha0=1.0) + alpha = nx.to_numpy(alpha) + np.testing.assert_allclose(alpha, 1.0) + nx.assert_same_dtype_device(old_fval, fval) + + # check the case where checking the wrong direction + alpha, _, fval = ot.optim.line_search_armijo(f, xkb, -pkb, gfkb, old_fval) + alpha = nx.to_numpy(alpha) + + assert alpha <= 0 + nx.assert_same_dtype_device(old_fval, fval) + + # check the case where the point is not a vector + xkb = nx.from_numpy(np.array(-5.0), type_as=tp) + pkb = nx.from_numpy(np.array(100), type_as=tp) + gfkb = grad(xkb) + old_fval = f(xkb) + alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval) + alpha = nx.to_numpy(alpha) + + np.testing.assert_allclose(alpha, 0.1) + nx.assert_same_dtype_device(old_fval, fval) -- cgit v1.2.3
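The two patches above route `cg`, `semirelaxed_cg` and `gcg` through the new `generic_conditional_gradient` driver, which takes the linear-programming oracle (`lp_solver`) and the step-size rule (`line_search`) as callables. A minimal usage sketch of the public `cg` wrapper with a squared-Frobenius regularizer, in the spirit of `test_optim.py`; the problem sizes and random data are illustrative only::

    import numpy as np
    import ot

    n_s, n_t = 5, 4
    rng = np.random.default_rng(0)
    a = ot.unif(n_s)                      # uniform source weights
    b = ot.unif(n_t)                      # uniform target weights
    M = ot.dist(rng.normal(size=(n_s, 2)), rng.normal(size=(n_t, 2)))

    def f(G):                             # regularizer: 0.5 * ||G||_F^2
        return 0.5 * np.sum(G ** 2)

    def df(G):                            # its gradient
        return G

    # cg() now calls generic_conditional_gradient with lp_solver=emd
    # and line_search=line_search_armijo under the hood
    G = ot.optim.cg(a, b, M, reg=1e-1, f=f, df=df)

    # both marginals hold because the direction-finding step solves an exact emd
    assert np.allclose(G.sum(axis=1), a)
    assert np.allclose(G.sum(axis=0), b)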
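For `semirelaxed_cg` only the source marginal is constrained, so the linearized subproblem min over {Gc >= 0, Gc 1 = a} of <Gc, Mi> decouples over rows: each row's mass a_i goes entirely to the column(s) where Mi is smallest. A plain-NumPy sketch of the backend-agnostic closure defined in the patch (the standalone function name is illustrative)::

    import numpy as np

    def semirelaxed_lp_solver(a, Mi):
        # 0/1 mask of the row-wise minima (ties allowed)
        mask = (Mi == Mi.min(axis=1, keepdims=True)).astype(Mi.dtype)
        # split each a_i uniformly over the tied minima of its row
        return mask * (a / mask.sum(axis=1))[:, None]

    a = np.array([0.5, 0.3, 0.2])
    Mi = np.array([[1.0, 0.2, 0.9],
                   [0.4, 0.4, 1.0],
                   [0.7, 0.3, 0.1]])
    Gc = semirelaxed_lp_solver(a, Mi)
    assert np.allclose(Gc.sum(axis=1), a)   # only the source marginal is enforced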
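`gcg` now plugs Sinkhorn in as the `lp_solver`, so the entropic term `reg1` is handled by the inner direction-finding step while `reg2 * f(G)` (plus the entropic contribution to the cost and its linearization) is handled by the outer loop of `generic_conditional_gradient`. A usage sketch with the same quadratic regularizer; sizes and data are illustrative, and the marginals are only approximately satisfied because the directions come from Sinkhorn::

    import numpy as np
    import ot

    n = 10
    rng = np.random.default_rng(1)
    a = ot.unif(n)
    b = ot.unif(n)
    M = ot.dist(rng.uniform(size=(n, 2)), rng.uniform(size=(n, 2)))

    def f(G):
        return 0.5 * np.sum(G ** 2)

    def df(G):
        return G

    # reg1 (entropic) is handled by the inner Sinkhorn direction finding,
    # reg2 * f(G) by the outer generalized conditional-gradient loop
    G = ot.optim.gcg(a, b, M, reg1=1e-1, reg2=1e-1, f=f, df=df, numItermax=20)

    assert G.shape == (n, n)   # marginals are approximate (Sinkhorn directions)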
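The exact line search used by the (semi-relaxed) Gromov-Wasserstein solvers reduces to minimizing a one-dimensional quadratic g(t) = a t^2 + b t over [0, 1], which is why `solve_1d_linesearch_quad` now takes only the two coefficients (the constant term does not move the argmin). A standalone sketch of that rule, checked against a brute-force grid; the helper name is illustrative::

    import numpy as np

    def quad_linesearch_step(a, b):
        # Minimize g(t) = a * t**2 + b * t over t in [0, 1].
        if a > 0:
            # strictly convex parabola: clip the vertex -b / (2a) to [0, 1]
            return min(1.0, max(0.0, -b / (2.0 * a)))
        # concave or linear: the minimum sits at an endpoint;
        # g(1) - g(0) = a + b decides which one
        return 1.0 if a + b < 0 else 0.0

    for a, b in [(2.0, -1.0), (1.0, -5.0), (-1.0, 0.2), (-1.0, 2.0)]:
        ts = np.linspace(0.0, 1.0, 10001)
        t_star = ts[np.argmin(a * ts ** 2 + b * ts)]
        assert abs(quad_linesearch_step(a, b) - t_star) < 1e-3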
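The second patch makes `line_search_armijo` backend-safe: inputs are converted to NumPy before calling SciPy's `scalar_search_armijo`, candidate steps are mapped back with `nx.from_numpy` inside `phi`, and `alpha`/`fa` are returned on the caller's backend, so CUDA tensors no longer break the Armijo path. A NumPy-only usage sketch mirroring the direction used in `test_optim.py`, where the exact minimizer along `pk` sits at alpha = 0.1::

    import numpy as np
    import ot

    def f(x):
        return np.sum((x - 5.0) ** 2)

    def grad(x):
        return 2.0 * (x - 5.0)

    xk = np.array([[-5.0, -5.0]])    # current point
    pk = np.array([[100.0, 100.0]])  # search direction; x(alpha) = xk + alpha * pk
    gfk = grad(xk)                   # gradient of f at xk
    old_fval = f(xk)

    alpha, n_calls, fval = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval)
    np.testing.assert_allclose(alpha, 0.1)   # exact minimizer of the quadratic along pk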