Diffstat (limited to 'ot/optim.py')
-rw-r--r-- | ot/optim.py | 102
1 file changed, 99 insertions(+), 3 deletions(-)
diff --git a/ot/optim.py b/ot/optim.py
index f31fae2..a774865 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -4,7 +4,7 @@ Optimization algorithms for OT
 """
 
 # Author: Remi Flamary <remi.flamary@unice.fr>
-#
+#         Titouan Vayer <titouan.vayer@irisa.fr>
 # License: MIT License
 
 import numpy as np
@@ -71,9 +71,70 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
     return alpha, fc[0], phi1
 
 
+def do_linesearch(cost, G, deltaG, Mi, f_val,
+                  amijo=False, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
+    """
+    Solve the line search in the FW iterations
+
+    Parameters
+    ----------
+    cost : method
+        The FGW cost
+    G : ndarray, shape (ns, nt)
+        The transport map at a given iteration of the FW algorithm
+    deltaG : ndarray, shape (ns, nt)
+        Difference between the optimal map found by linearization in the FW
+        algorithm and the value at a given iteration
+    Mi : ndarray, shape (ns, nt)
+        Cost matrix of the linearized transport problem. Corresponds to the
+        gradient of the cost
+    f_val : float
+        Value of the cost at G
+    amijo : bool, optional
+        If True the step of the line search is found via an Armijo search.
+        Else the closed form is used. If there are convergence issues, use
+        False.
+    C1 : ndarray, shape (ns, ns), optional
+        Structure matrix in the source domain. Only used when amijo=False
+    C2 : ndarray, shape (nt, nt), optional
+        Structure matrix in the target domain. Only used when amijo=False
+    reg : float, optional
+        Regularization parameter. Corresponds to the alpha parameter of FGW.
+        Only used when amijo=False
+    Gc : ndarray, shape (ns, nt)
+        Optimal map found by linearization in the FW algorithm. Only used
+        when amijo=False
+    constC : ndarray, shape (ns, nt)
+        Constant for the Gromov cost. See [18]. Only used when amijo=False
+    M : ndarray, shape (ns, nt), optional
+        Cost matrix between the features. Only used when amijo=False
+
+    Returns
+    -------
+    alpha : float
+        The optimal step size of the FW
+    fc : int
+        Number of function calls. Not used here
+    f_val : float
+        The value of the cost for the next iteration
+
+    References
+    ----------
+    .. [18] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+        and Courty Nicolas,
+        "Optimal Transport for structured data with application on graphs",
+        International Conference on Machine Learning (ICML). 2019.
+    """
+    if amijo:
+        alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val)
+    else:  # requires symmetric matrices
+        dot1 = np.dot(C1, deltaG)
+        dot12 = dot1.dot(C2)
+        a = -2 * reg * np.sum(dot12 * deltaG)
+        b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG))
+        c = cost(G)
+
+        alpha = solve_1d_linesearch_quad_funct(a, b, c)
+        fc = None
+        f_val = cost(G + alpha * deltaG)
+
+    return alpha, fc, f_val
+
+
 def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
-       stopThr=1e-9, verbose=False, log=False):
+       stopThr=1e-9, verbose=False, log=False, **kwargs):
     """
     Solve the general regularized OT problem with conditional gradient
@@ -116,6 +177,8 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
         Print information along iterations
     log : bool, optional
         record log if True
+    kwargs : dict
+        Parameters for linesearch
 
     Returns
     -------
@@ -177,7 +240,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
         deltaG = Gc - G
 
         # line search
-        alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val)
+        alpha, fc, f_val = do_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs)
 
         G = G + alpha * deltaG
 
@@ -339,3 +402,36 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
         return G, log
     else:
         return G
+
+
+def solve_1d_linesearch_quad_funct(a, b, c):
+    """
+    Solve on [0, 1] the following problem:
+
+    .. math::
+
+        \min_{0 \leq x \leq 1} f(x) = ax^{2} + bx + c
+
+    Parameters
+    ----------
+    a, b, c : float
+        The coefficients of the quadratic function
+
+    Returns
+    -------
+    x : float
+        The optimal value which leads to the minimal cost
+    """
+    f0 = c
+    df0 = b
+    f1 = a + f0 + df0
+
+    if a > 0:  # convex case: clip the vertex -b/(2a) to [0, 1]
+        minimum = min(1, max(0, -b / (2 * a)))
+        return minimum
+    else:  # non-convex (or linear) case: the minimum is at an endpoint
+        if f0 > f1:
+            return 1
+        else:
+            return 0
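To make the case analysis in solve_1d_linesearch_quad_funct concrete, here is a minimal standalone sketch of the same logic together with a brute-force grid check. It is plain NumPy, not imported from POT, and the helper name quad_min_on_unit_interval is ours, chosen to avoid shadowing the function added by this commit:

import numpy as np


def quad_min_on_unit_interval(a, b, c):
    """Minimize f(x) = a*x**2 + b*x + c over [0, 1] (mirrors the diff's logic)."""
    if a > 0:  # strictly convex: clip the unconstrained minimizer -b/(2a) to [0, 1]
        return min(1.0, max(0.0, -b / (2.0 * a)))
    # concave or linear: the minimum over a closed interval sits at an endpoint,
    # so compare f(0) = c with f(1) = a + b + c
    return 1.0 if c > a + b + c else 0.0


# Brute-force check against a dense grid for random coefficients
rng = np.random.RandomState(0)
xs = np.linspace(0.0, 1.0, 10001)
for _ in range(1000):
    a, b, c = rng.randn(3)
    x_star = quad_min_on_unit_interval(a, b, c)
    f_star = a * x_star ** 2 + b * x_star + c
    assert f_star <= (a * xs ** 2 + b * xs + c).min() + 1e-10
print("closed-form minimizer matches the grid search")

Note that the non-convex branch never needs the vertex: a quadratic with a <= 0 attains its minimum over a closed interval at one of the endpoints, so comparing f(0) with f(1) is enough, and this also covers the linear case a = 0.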
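Upstream of that solver, the closed-form branch of do_linesearch relies on the objective being quadratic in G, so that phi(alpha) = cost(G + alpha * deltaG) is a one-dimensional quadratic in alpha. The sketch below checks the coefficient formulas from the diff against direct evaluation of such a cost. The cost expression here is an assumption chosen to have the same algebraic shape as the square-loss (F)GW objective, and all matrices are random illustrative data, not anything produced by POT:

import numpy as np

rng = np.random.RandomState(42)
ns, nt, reg = 5, 4, 0.5

# Illustrative random data: symmetric structure matrices, feature cost, maps
C1 = rng.rand(ns, ns)
C1 = (C1 + C1.T) / 2
C2 = rng.rand(nt, nt)
C2 = (C2 + C2.T) / 2
M = rng.rand(ns, nt)
constC = rng.rand(ns, nt)
G = rng.rand(ns, nt)    # current transport map
Gc = rng.rand(ns, nt)   # map returned by the linearized step
deltaG = Gc - G


def cost(G):
    # Assumed objective: a linear term plus a -2*reg*<C1 G C2, G> quadratic
    # term (the algebraic shape of the square-loss Gromov cost)
    return np.sum((M + reg * constC) * G) - 2 * reg * np.sum(np.dot(C1, G).dot(C2) * G)


# Coefficients of phi(alpha) = cost(G + alpha * deltaG), as in do_linesearch
dot12 = np.dot(C1, deltaG).dot(C2)
a = -2 * reg * np.sum(dot12 * deltaG)
b = np.sum((M + reg * constC) * deltaG) \
    - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG))
c = cost(G)

for alpha in (0.0, 0.25, 0.5, 1.0):
    assert np.isclose(a * alpha ** 2 + b * alpha + c, cost(G + alpha * deltaG))
print("phi(alpha) = a*alpha**2 + b*alpha + c matches the cost along the segment")

With a, b, c in hand, the step is exactly solve_1d_linesearch_quad_funct(a, b, c), which avoids the extra cost evaluations an Armijo backtracking search would require.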