Diffstat (limited to 'ot/optim.py')
-rw-r--r-- | ot/optim.py | 155
1 files changed, 93 insertions, 62 deletions
diff --git a/ot/optim.py b/ot/optim.py
index 0359343..6822e4e 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -12,6 +12,8 @@ import numpy as np
 from scipy.optimize.linesearch import scalar_search_armijo
 from .lp import emd
 from .bregman import sinkhorn
+from ot.utils import list_to_array
+from .backend import get_backend
 
 
 # The corresponding scipy function does not work for matrices
@@ -21,25 +23,25 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
     """
     Armijo linesearch function that works with matrices
 
-    find an approximate minimum of f(xk+alpha*pk) that satifies the
+    Find an approximate minimum of :math:`f(x_k + \\alpha \cdot p_k)` that satisfies the
     armijo conditions.
 
     Parameters
     ----------
     f : callable
         loss function
-    xk : ndarray
+    xk : array-like
         initial position
-    pk : ndarray
+    pk : array-like
         descent direction
-    gfk : ndarray
-        gradient of f at xk
+    gfk : array-like
+        gradient of `f` at :math:`x_k`
     old_fval : float
-        loss value at xk
+        loss value at :math:`x_k`
     args : tuple, optional
-        arguments given to f
+        arguments given to `f`
     c1 : float, optional
-        c1 const in armijo rule (>0)
+        :math:`c_1` const in armijo rule (>0)
     alpha0 : float, optional
         initial step (>0)
@@ -53,7 +55,13 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
         loss value at step alpha
 
     """
-    xk = np.atleast_1d(xk)
+
+    xk, pk, gfk = list_to_array(xk, pk, gfk)
+    nx = get_backend(xk, pk)
+
+    if len(xk.shape) == 0:
+        xk = nx.reshape(xk, (-1,))
+
     fc = [0]
 
     def phi(alpha1):
@@ -65,7 +73,7 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
     else:
         phi0 = old_fval
 
-    derphi0 = np.sum(pk * gfk)  # Quickfix for matrices
+    derphi0 = nx.sum(pk * gfk)  # Quickfix for matrices
     alpha, phi1 = scalar_search_armijo(
         phi, phi0, derphi0, c1=c1, alpha0=alpha0)
@@ -79,55 +87,64 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
                      armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
     """
     Solve the linesearch in the FW iterations
+
     Parameters
     ----------
     cost : method
         Cost in the FW for the linesearch
-    G : ndarray, shape(ns,nt)
+    G : array-like, shape(ns,nt)
         The transport map at a given iteration of the FW
-    deltaG : ndarray (ns,nt)
+    deltaG : array-like (ns,nt)
         Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration
-    Mi : ndarray (ns,nt)
+    Mi : array-like (ns,nt)
         Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost
-    f_val : float
-        Value of the cost at G
+    f_val : float
+        Value of the cost at `G`
     armijo : bool, optional
-        If True the steps of the line-search is found via an armijo research. Else closed form is used.
-        If there is convergence issues use False.
-    C1 : ndarray (ns,ns), optional
+        If True the steps of the line-search is found via an armijo research. Else closed form is used.
+        If there is convergence issues use False.
+    C1 : array-like (ns,ns), optional
         Structure matrix in the source domain. Only used and necessary when armijo=False
-    C2 : ndarray (nt,nt), optional
+    C2 : array-like (nt,nt), optional
         Structure matrix in the target domain. Only used and necessary when armijo=False
     reg : float, optional
-        Regularization parameter. Only used and necessary when armijo=False
-    Gc : ndarray (ns,nt)
+        Regularization parameter. Only used and necessary when armijo=False
+    Gc : array-like (ns,nt)
         Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False
-    constC : ndarray (ns,nt)
-        Constant for the gromov cost. See [24]. Only used and necessary when armijo=False
-    M : ndarray (ns,nt), optional
+    constC : array-like (ns,nt)
+        Constant for the gromov cost. See :ref:`[24] <references-solve-linesearch>`. Only used and necessary when armijo=False
+    M : array-like (ns,nt), optional
         Cost matrix between the features. Only used and necessary when armijo=False
+
     Returns
     -------
     alpha : float
-        The optimal step size of the FW
+        The optimal step size of the FW
     fc : int
-        nb of function call. Useless here
-    f_val : float
-        The value of the cost for the next iteration
+        nb of function call. Useless here
+    f_val : float
+        The value of the cost for the next iteration
+
+
+    .. _references-solve-linesearch:
     References
     ----------
-    .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
-        and Courty Nicolas
+    .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain and Courty Nicolas
         "Optimal Transport for structured data with application on graphs"
         International Conference on Machine Learning (ICML). 2019.
     """
     if armijo:
         alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val)
     else:  # requires symetric matrices
-        dot1 = np.dot(C1, deltaG)
-        dot12 = dot1.dot(C2)
-        a = -2 * reg * np.sum(dot12 * deltaG)
-        b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG))
+        G, deltaG, C1, C2, constC, M = list_to_array(G, deltaG, C1, C2, constC, M)
+        if isinstance(M, int) or isinstance(M, float):
+            nx = get_backend(G, deltaG, C1, C2, constC)
+        else:
+            nx = get_backend(G, deltaG, C1, C2, constC, M)
+
+        dot = nx.dot(nx.dot(C1, deltaG), C2)
+        a = -2 * reg * nx.sum(dot * deltaG)
+        b = nx.sum((M + reg * constC) * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2) * deltaG))
         c = cost(G)
 
         alpha = solve_1d_linesearch_quad(a, b, c)
@@ -145,33 +162,33 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
     The function solves the following optimization problem:
 
     .. math::
-        \gamma = arg\min_\gamma <\gamma,M>_F + reg*f(\gamma)
+        \gamma = arg\min_\gamma <\gamma,M>_F + \mathrm{reg} \cdot f(\gamma)
 
-        s.t. \gamma 1 = a
+        s.t. \ \gamma 1 = a
 
              \gamma^T 1= b
 
             \gamma\geq 0
     where :
 
-    - M is the (ns,nt) metric cost matrix
-    - :math:`f` is the regularization term ( and df is its gradient)
-    - a and b are source and target weights (sum to 1)
+    - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+    - :math:`f` is the regularization term (and `df` is its gradient)
+    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
 
-    The algorithm used for solving the problem is conditional gradient as discussed in [1]_
+    The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>`
 
     Parameters
     ----------
-    a : ndarray, shape (ns,)
+    a : array-like, shape (ns,)
         samples weights in the source domain
-    b : ndarray, shape (nt,)
+    b : array-like, shape (nt,)
         samples in the target domain
-    M : ndarray, shape (ns, nt)
+    M : array-like, shape (ns, nt)
         loss matrix
     reg : float
         Regularization term >0
-    G0 : ndarray, shape (ns,nt), optional
+    G0 : array-like, shape (ns,nt), optional
         initial guess (default is indep joint density)
     numItermax : int, optional
         Max number of iterations
@@ -196,6 +213,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
         log dictionary return only if log==True in parameters
 
 
+    .. _references-cg:
     References
     ----------
@@ -207,6 +225,11 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
     ot.bregman.sinkhorn : Entropic regularized optimal transport
 
     """
+    a, b, M, G0 = list_to_array(a, b, M, G0)
+    if isinstance(M, int) or isinstance(M, float):
+        nx = get_backend(a, b)
+    else:
+        nx = get_backend(a, b, M)
 
     loop = 1
@@ -214,12 +237,12 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
         log = {'loss': []}
 
     if G0 is None:
-        G = np.outer(a, b)
+        G = nx.outer(a, b)
     else:
         G = G0
 
     def cost(G):
-        return np.sum(M * G) + reg * f(G)
+        return nx.sum(M * G) + reg * f(G)
 
     f_val = cost(G)
     if log:
@@ -240,7 +263,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
         # problem linearization
         Mi = M + reg * df(G)
         # set M positive
-        Mi += Mi.min()
+        Mi += nx.min(Mi)
 
         # solve linear program
         Gc = emd(a, b, Mi, numItermax=numItermaxEmd)
@@ -286,36 +309,36 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
     The function solves the following optimization problem:
 
     .. math::
-        \gamma = arg\min_\gamma <\gamma,M>_F + reg1\cdot\Omega(\gamma) + reg2\cdot f(\gamma)
+        \gamma = arg\min_\gamma <\gamma,M>_F + \mathrm{reg_1}\cdot\Omega(\gamma) + \mathrm{reg_2}\cdot f(\gamma)
 
-        s.t. \gamma 1 = a
+        s.t. \ \gamma 1 = a
 
              \gamma^T 1= b
 
             \gamma\geq 0
     where :
 
-    - M is the (ns,nt) metric cost matrix
+    - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
     - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
-    - :math:`f` is the regularization term ( and df is its gradient)
-    - a and b are source and target weights (sum to 1)
+    - :math:`f` is the regularization term (and `df` is its gradient)
+    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
 
-    The algorithm used for solving the problem is the generalized conditional gradient as discussed in [5,7]_
+    The algorithm used for solving the problem is the generalized conditional gradient as discussed in :ref:`[5, 7] <references-gcg>`
 
     Parameters
     ----------
-    a : ndarray, shape (ns,)
+    a : array-like, shape (ns,)
         samples weights in the source domain
-    b : ndarrayv (nt,)
+    b : array-like, (nt,)
        samples in the target domain
-    M : ndarray, shape (ns, nt)
+    M : array-like, shape (ns, nt)
        loss matrix
     reg1 : float
        Entropic Regularization term >0
     reg2 : float
        Second Regularization term >0
-    G0 : ndarray, shape (ns, nt), optional
+    G0 : array-like, shape (ns, nt), optional
        initial guess (default is indep joint density)
     numItermax : int, optional
        Max number of iterations
@@ -337,9 +360,13 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
     log : dict
         log dictionary return only if log==True in parameters
 
+
+    .. _references-gcg:
     References
     ----------
+    .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
+    .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
 
     See Also
@@ -347,6 +374,8 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
     ot.optim.cg : conditional gradient
 
     """
+    a, b, M, G0 = list_to_array(a, b, M, G0)
+    nx = get_backend(a, b, M)
 
     loop = 1
@@ -354,12 +383,12 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
         log = {'loss': []}
 
     if G0 is None:
-        G = np.outer(a, b)
+        G = nx.outer(a, b)
     else:
         G = G0
 
     def cost(G):
-        return np.sum(M * G) + reg1 * np.sum(G * np.log(G)) + reg2 * f(G)
+        return nx.sum(M * G) + reg1 * nx.sum(G * nx.log(G)) + reg2 * f(G)
 
     f_val = cost(G)
     if log:
@@ -387,7 +416,7 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
         deltaG = Gc - G
 
         # line search
-        dcost = Mi + reg1 * (1 + np.log(G))  # ??
+        dcost = Mi + reg1 * (1 + nx.log(G))  # ??
         alpha, fc, f_val = line_search_armijo(cost, G, deltaG, dcost, f_val)
 
         G = G + alpha * deltaG
@@ -419,9 +448,11 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
 def solve_1d_linesearch_quad(a, b, c):
     """
-    For any convex or non-convex 1d quadratic function f, solve on [0,1] the following problem:
+    For any convex or non-convex 1d quadratic function `f`, solve the following problem:
+
     .. math::
-        \argmin f(x)=a*x^{2}+b*x+c
+
+        arg\min_{0 \leq x \leq 1} f(x) = ax^{2} + bx + c
 
     Parameters
     ----------
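
Note: the common thread of this patch is replacing direct numpy calls with the get_backend / list_to_array dispatch so the solvers accept arrays from any supported backend. A minimal sketch of that pattern, using only the helpers imported in this diff and made-up toy inputs:

import numpy as np
from ot.utils import list_to_array
from ot.backend import get_backend

a = [0.5, 0.5]                  # a plain Python list is accepted ...
b = np.array([0.5, 0.5])
a, b = list_to_array(a, b)      # ... and promoted to an array of the detected backend
nx = get_backend(a, b)          # numpy here; another backend if the inputs came from one

G = nx.outer(a, b)              # every former np.* call in the solvers goes through nx
print(nx.sum(G))                # 1.0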
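
Note: a hedged usage sketch of the cg solver documented above, with the classic squared-l2 regularizer f(G) = 0.5 * ||G||_F^2 (so df(G) = G); the toy data is illustrative and not part of the patch:

import numpy as np
import ot

n = 20
a = ot.unif(n)                             # uniform source weights
b = ot.unif(n)                             # uniform target weights
x = np.linspace(0, 1, n).reshape(-1, 1)
M = ot.dist(x, x)                          # squared Euclidean cost matrix
M /= M.max()

def f(G):                                  # regularizer ...
    return 0.5 * np.sum(G ** 2)

def df(G):                                 # ... and its gradient
    return G

G = ot.optim.cg(a, b, M, 1e-1, f, df)      # reg = 1e-1
print(G.sum())                             # total transported mass, ~1.0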
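
Note: the closed-form branch (armijo=False) ends in solve_1d_linesearch_quad, which minimizes the 1d quadratic a*x**2 + b*x + c over [0, 1]. A sketch of that closed form (an illustrative helper, not necessarily the library's exact code): for a convex parabola the vertex -b/(2a) is clipped into [0, 1]; otherwise the minimum sits on an endpoint.

def quad_linesearch_01(a, b, c):
    # Minimize f(x) = a*x**2 + b*x + c on the interval [0, 1].
    if a > 0:                              # convex: clip the unconstrained vertex into [0, 1]
        return min(1.0, max(0.0, -b / (2.0 * a)))
    f0, f1 = c, a + b + c                  # concave or linear: compare f(0) and f(1)
    return 1.0 if f1 < f0 else 0.0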