author | Gard Spreemann <gspr@nonempty.org> | 2020-01-20 14:07:53 +0100
---|---|---
committer | Gard Spreemann <gspr@nonempty.org> | 2020-01-20 14:07:53 +0100
commit | bdfb24ff37ea777d6e266b145047cd4e281ebac3 (patch) |
tree | 00cbac5f3dc25a4ee76164828abd72c1cbab37cc /ot/optim.py |
parent | abc441b00f0fe2fa4ef0efc4e1aa67b27cca9a13 (diff) |
parent | 5e70a77fbb2feec513f21c9ef65dcc535329ace6 (diff) |
Merge tag '0.6.0' into debian/sid
Diffstat (limited to 'ot/optim.py')
-rw-r--r-- | ot/optim.py | 440
1 files changed, 440 insertions, 0 deletions
diff --git a/ot/optim.py b/ot/optim.py
new file mode 100644
index 0000000..0abd9e9
--- /dev/null
+++ b/ot/optim.py
@@ -0,0 +1,440 @@
+# -*- coding: utf-8 -*-
+"""
+Optimization algorithms for OT
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#         Titouan Vayer <titouan.vayer@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+from scipy.optimize.linesearch import scalar_search_armijo
+from .lp import emd
+from .bregman import sinkhorn
+
+# The corresponding scipy function does not work for matrices
+
+
+def line_search_armijo(f, xk, pk, gfk, old_fval,
+                       args=(), c1=1e-4, alpha0=0.99):
+    """
+    Armijo linesearch function that works with matrices
+
+    Find an approximate minimum of f(xk + alpha*pk) that satisfies the
+    Armijo conditions.
+
+    Parameters
+    ----------
+    f : callable
+        loss function
+    xk : ndarray
+        initial position
+    pk : ndarray
+        descent direction
+    gfk : ndarray
+        gradient of f at xk
+    old_fval : float
+        loss value at xk
+    args : tuple, optional
+        arguments given to f
+    c1 : float, optional
+        c1 const in armijo rule (>0)
+    alpha0 : float, optional
+        initial step (>0)
+
+    Returns
+    -------
+    alpha : float
+        step that satisfies the Armijo conditions
+    fc : int
+        nb of function call
+    fa : float
+        loss value at step alpha
+
+    """
+    xk = np.atleast_1d(xk)
+    fc = [0]
+
+    def phi(alpha1):
+        fc[0] += 1
+        return f(xk + alpha1 * pk, *args)
+
+    if old_fval is None:
+        phi0 = phi(0.)
+    else:
+        phi0 = old_fval
+
+    derphi0 = np.sum(pk * gfk)  # Quickfix for matrices
+    alpha, phi1 = scalar_search_armijo(
+        phi, phi0, derphi0, c1=c1, alpha0=alpha0)
+
+    return alpha, fc[0], phi1
+
+
+def solve_linesearch(cost, G, deltaG, Mi, f_val,
+                     armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
+    """
+    Solve the linesearch in the FW iterations
+
+    Parameters
+    ----------
+    cost : method
+        Cost in the FW for the linesearch
+    G : ndarray, shape (ns, nt)
+        The transport map at a given iteration of the FW
+    deltaG : ndarray, shape (ns, nt)
+        Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration
+    Mi : ndarray, shape (ns, nt)
+        Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost
+    f_val : float
+        Value of the cost at G
+    armijo : bool, optional
+        If True the step of the line-search is found via an Armijo search, else a closed form is used.
+        If there are convergence issues use False.
+    C1 : ndarray, shape (ns, ns), optional
+        Structure matrix in the source domain. Only used and necessary when armijo=False
+    C2 : ndarray, shape (nt, nt), optional
+        Structure matrix in the target domain. Only used and necessary when armijo=False
+    reg : float, optional
+        Regularization parameter. Only used and necessary when armijo=False
+    Gc : ndarray, shape (ns, nt)
+        Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False
+    constC : ndarray, shape (ns, nt)
+        Constant for the gromov cost. See [24]. Only used and necessary when armijo=False
+    M : ndarray, shape (ns, nt), optional
+        Cost matrix between the features. Only used and necessary when armijo=False
+
+    Returns
+    -------
+    alpha : float
+        The optimal step size of the FW
+    fc : int
+        nb of function call. Useless here
+    f_val : float
+        The value of the cost for the next iteration
+
+    References
+    ----------
+    .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+        and Courty Nicolas
+        "Optimal Transport for structured data with application on graphs"
+        International Conference on Machine Learning (ICML). 2019.
+ """ + if armijo: + alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val) + else: # requires symetric matrices + dot1 = np.dot(C1, deltaG) + dot12 = dot1.dot(C2) + a = -2 * reg * np.sum(dot12 * deltaG) + b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG)) + c = cost(G) + + alpha = solve_1d_linesearch_quad(a, b, c) + fc = None + f_val = cost(G + alpha * deltaG) + + return alpha, fc, f_val + + +def cg(a, b, M, reg, f, df, G0=None, numItermax=200, + stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs): + """ + Solve the general regularized OT problem with conditional gradient + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg*f(\gamma) + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - M is the (ns,nt) metric cost matrix + - :math:`f` is the regularization term ( and df is its gradient) + - a and b are source and target weights (sum to 1) + + The algorithm used for solving the problem is conditional gradient as discussed in [1]_ + + + Parameters + ---------- + a : ndarray, shape (ns,) + samples weights in the source domain + b : ndarray, shape (nt,) + samples in the target domain + M : ndarray, shape (ns, nt) + loss matrix + reg : float + Regularization term >0 + G0 : ndarray, shape (ns,nt), optional + initial guess (default is indep joint density) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on the relative variation (>0) + stopThr2 : float, optional + Stop threshol on the absolute variation (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + **kwargs : dict + Parameters for linesearch + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. 
+
+    See Also
+    --------
+    ot.lp.emd : Unregularized optimal transport
+    ot.bregman.sinkhorn : Entropic regularized optimal transport
+
+    """
+
+    loop = 1
+
+    if log:
+        log = {'loss': []}
+
+    if G0 is None:
+        G = np.outer(a, b)
+    else:
+        G = G0
+
+    def cost(G):
+        return np.sum(M * G) + reg * f(G)
+
+    f_val = cost(G)
+    if log:
+        log['loss'].append(f_val)
+
+    it = 0
+
+    if verbose:
+        print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+            'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
+        print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0))
+
+    while loop:
+
+        it += 1
+        old_fval = f_val
+
+        # problem linearization
+        Mi = M + reg * df(G)
+        # set M positive
+        Mi += Mi.min()
+
+        # solve linear program
+        Gc = emd(a, b, Mi)
+
+        deltaG = Gc - G
+
+        # line search
+        alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs)
+
+        G = G + alpha * deltaG
+
+        # test convergence
+        if it >= numItermax:
+            loop = 0
+
+        abs_delta_fval = abs(f_val - old_fval)
+        relative_delta_fval = abs_delta_fval / abs(f_val)
+        if relative_delta_fval < stopThr or abs_delta_fval < stopThr2:
+            loop = 0
+
+        if log:
+            log['loss'].append(f_val)
+
+        if verbose:
+            if it % 20 == 0:
+                print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+                    'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
+            print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
+
+    if log:
+        return G, log
+    else:
+        return G
+
+
+def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
+        numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False):
+    """
+    Solve the general regularized OT problem with the generalized conditional gradient
+
+    The function solves the following optimization problem:
+
+    .. math::
+        \gamma = arg\min_\gamma <\gamma, M>_F + reg1\cdot\Omega(\gamma) + reg2\cdot f(\gamma)
+
+        s.t. \gamma 1 = a
+
+             \gamma^T 1 = b
+
+             \gamma \geq 0
+
+    where :
+
+    - M is the (ns, nt) metric cost matrix
+    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+    - :math:`f` is the regularization term (and df is its gradient)
+    - a and b are source and target weights (sum to 1)
+
+    The algorithm used for solving the problem is the generalized conditional gradient as discussed in [5, 7]_
+
+    Parameters
+    ----------
+    a : ndarray, shape (ns,)
+        sample weights in the source domain
+    b : ndarray, shape (nt,)
+        sample weights in the target domain
+    M : ndarray, shape (ns, nt)
+        loss matrix
+    reg1 : float
+        Entropic Regularization term >0
+    reg2 : float
+        Second Regularization term >0
+    G0 : ndarray, shape (ns, nt), optional
+        initial guess (default is indep joint density)
+    numItermax : int, optional
+        Max number of iterations
+    numInnerItermax : int, optional
+        Max number of iterations of Sinkhorn
+    stopThr : float, optional
+        Stop threshold on the relative variation (>0)
+    stopThr2 : float, optional
+        Stop threshold on the absolute variation (>0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+    Returns
+    -------
+    gamma : ndarray, shape (ns, nt)
+        Optimal transportation matrix for the given parameters
+    log : dict
+        log dictionary returned only if log==True in parameters
+
+    References
+    ----------
+    .. [5] N. Courty, R. Flamary, D. Tuia, A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence, vol. PP, no. 99, pp. 1-1
+    .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
+        Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
+
+    See Also
+    --------
+    ot.optim.cg : conditional gradient
+
+    """
+
+    loop = 1
+
+    if log:
+        log = {'loss': []}
+
+    if G0 is None:
+        G = np.outer(a, b)
+    else:
+        G = G0
+
+    def cost(G):
+        return np.sum(M * G) + reg1 * np.sum(G * np.log(G)) + reg2 * f(G)
+
+    f_val = cost(G)
+    if log:
+        log['loss'].append(f_val)
+
+    it = 0
+
+    if verbose:
+        print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+            'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
+        print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0))
+
+    while loop:
+
+        it += 1
+        old_fval = f_val
+
+        # problem linearization
+        Mi = M + reg2 * df(G)
+
+        # solve linear program with Sinkhorn
+        # Gc = sinkhorn_stabilized(a, b, Mi, reg1, numItermax=numInnerItermax)
+        Gc = sinkhorn(a, b, Mi, reg1, numItermax=numInnerItermax)
+
+        deltaG = Gc - G
+
+        # line search
+        dcost = Mi + reg1 * (1 + np.log(G))  # gradient of the full cost, entropic term included
+        alpha, fc, f_val = line_search_armijo(cost, G, deltaG, dcost, f_val)
+
+        G = G + alpha * deltaG
+
+        # test convergence
+        if it >= numItermax:
+            loop = 0
+
+        abs_delta_fval = abs(f_val - old_fval)
+        relative_delta_fval = abs_delta_fval / abs(f_val)
+
+        if relative_delta_fval < stopThr or abs_delta_fval < stopThr2:
+            loop = 0
+
+        if log:
+            log['loss'].append(f_val)
+
+        if verbose:
+            if it % 20 == 0:
+                print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+                    'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
+            print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
+
+    if log:
+        return G, log
+    else:
+        return G
+
+
+def solve_1d_linesearch_quad(a, b, c):
+    """
+    For any convex or non-convex 1d quadratic function f, solve the following problem on [0, 1]:
+
+    .. math::
+        \arg\min_{x \in [0, 1]} f(x) = ax^{2} + bx + c
+
+    Parameters
+    ----------
+    a, b, c : float
+        The coefficients of the quadratic function
+
+    Returns
+    -------
+    x : float
+        The optimal value which leads to the minimal cost
+    """
+    f0 = c
+    df0 = b
+    f1 = a + f0 + df0
+
+    if a > 0:  # convex
+        minimum = min(1, max(0, np.divide(-b, 2.0 * a)))
+        return minimum
+    else:  # non convex
+        if f0 > f1:
+            return 1
+        else:
+            return 0
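The diff above adds the whole of ot/optim.py. For orientation, the sketch below shows one possible way to call the two solvers it defines; the quadratic regularizer f/df, the problem size and the cost matrix are illustrative assumptions for this example, not part of the commit.

# Minimal usage sketch for ot.optim.cg and ot.optim.gcg (assumed example data).
import numpy as np
from ot.optim import cg, gcg

n = 20
a = np.ones(n) / n                    # uniform source weights
b = np.ones(n) / n                    # uniform target weights
x = np.arange(n, dtype=np.float64)
M = (x[:, None] - x[None, :]) ** 2    # squared Euclidean cost on a 1D grid
M /= M.max()

def f(G):
    return 0.5 * np.sum(G ** 2)       # example regularization term f(gamma)

def df(G):
    return G                          # its gradient

G_cg = cg(a, b, M, reg=1e-1, f=f, df=df)                 # conditional gradient
G_gcg = gcg(a, b, M, reg1=1e-1, reg2=1e-1, f=f, df=df)   # generalized CG with entropic term

print(G_cg.sum(), G_gcg.sum())        # both transport plans should have total mass ~1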