From 8097bdda7c0fbd469eecb55fc0a1e93dd53b7fb8 Mon Sep 17 00:00:00 2001
From: Nicolas Courty
Date: Mon, 7 Nov 2016 23:45:38 +0100
Subject: gcg

---
 ot/optim.py | 136 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 135 insertions(+), 1 deletion(-)

(limited to 'ot/optim.py')

diff --git a/ot/optim.py b/ot/optim.py
index 7ed658c..598e23f 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -4,9 +4,9 @@ Optimization algorithms for OT
 """

 import numpy as np
-import scipy as sp
 from scipy.optimize.linesearch import scalar_search_armijo
 from .lp import emd
+from .bregman import sinkhorn_stabilized

 # The corresponding scipy function does not work for matrices
 def line_search_armijo(f,xk,pk,gfk,old_fval,args=(),c1=1e-4,alpha0=0.99):
@@ -195,3 +195,137 @@ def cg(a,b,M,reg,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=Fa
     else:
         return G

+def gcg(a,b,M,reg1,reg2,f,df,G0=None,numItermax=200,stopThr=1e-9,verbose=False,log=False):
+    """
+    Solve the general regularized OT problem with the generalized conditional gradient
+
+    The function solves the following optimization problem:
+
+    .. math::
+        \gamma = \arg\min_\gamma \langle \gamma,M \rangle_F + reg1\cdot\Omega(\gamma) + reg2\cdot f(\gamma)
+
+        s.t. \gamma 1 = a
+
+             \gamma^T 1 = b
+
+             \gamma \geq 0
+    where :
+
+    - M is the (ns,nt) metric cost matrix
+    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+    - :math:`f` is the regularization term (and df is its gradient)
+    - a and b are source and target weights (sum to 1)
+
+    The algorithm used for solving the problem is the generalized conditional gradient as discussed in [5,7]_
+
+    Parameters
+    ----------
+    a : np.ndarray (ns,)
+        sample weights in the source domain
+    b : np.ndarray (nt,)
+        sample weights in the target domain
+    M : np.ndarray (ns,nt)
+        loss matrix
+    reg1 : float
+        Entropic regularization term >0
+    reg2 : float
+        Second regularization term >0
+    G0 : np.ndarray (ns,nt), optional
+        initial guess (default is indep joint density)
+    numItermax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (>0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+    Returns
+    -------
+    gamma : (ns x nt) ndarray
+        Optimal transportation matrix for the given parameters
+    log : dict
+        log dictionary, returned only if log==True in parameters
+
+    References
+    ----------
+
+    .. [5] N. Courty, R. Flamary, D. Tuia, A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence, vol. PP, no. 99, pp. 1-1
+    .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
+
+    See Also
+    --------
+    ot.optim.cg : conditional gradient
+
+    """
+
+    loop=1
+
+    if log:
+        log={'loss':[]}
+
+    if G0 is None:
+        G=np.outer(a,b)
+    else:
+        G=G0
+
+    def cost(G):
+        return np.sum(M*G)+reg1*np.sum(G*np.log(G))+reg2*f(G)
+
+    f_val=cost(G)
+    if log:
+        log['loss'].append(f_val)
+
+    it=0
+
+    if verbose:
+        print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
+        print('{:5d}|{:8e}|{:8e}'.format(it,f_val,0))
+
+    while loop:
+
+        it+=1
+        old_fval=f_val
+
+        # problem linearization (gradient of the non-entropic part only)
+        Mi=M+reg2*df(G)
+        # set Mi nonnegative (shift so that its minimum is 0)
+        Mi-=Mi.min()
+
+        # solve the entropy-regularized linearized subproblem with Sinkhorn
+        Gc=sinkhorn_stabilized(a,b,Mi,reg1)
+
+        deltaG=Gc-G
+
+        # line search along deltaG (Mi serves as the surrogate gradient)
+        alpha,fc,f_val=line_search_armijo(cost,G,deltaG,Mi,f_val)
+
+        G=G+alpha*deltaG
+
+        # test convergence
+        if it>=numItermax:
+            loop=0
+
+        delta_fval=(f_val-old_fval)/abs(f_val)
+        if abs(delta_fval)<stopThr:
+            loop=0
+
+        if log:
+            log['loss'].append(f_val)
+
+        if verbose:
+            if it%20==0:
+                print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
+            print('{:5d}|{:8e}|{:8e}'.format(it,f_val,delta_fval))
+
+    if log:
+        return G,log
+    else:
+        return G
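
For reference, a minimal usage sketch of the new gcg solver (not part of the patch). The quadratic regularizer f(G) = 0.5*||G||_F^2, the random data and the regularization weights below are illustrative assumptions only, not values taken from the library:

    import numpy as np
    from ot.optim import gcg

    # illustrative problem: uniform marginals and a random cost matrix
    rng = np.random.RandomState(0)
    ns, nt = 20, 30
    a = np.ones(ns) / ns   # source sample weights (sum to 1)
    b = np.ones(nt) / nt   # target sample weights (sum to 1)
    M = rng.rand(ns, nt)   # loss matrix

    # assumed smooth regularizer for the example: 0.5*||G||_F^2, gradient G
    def f(G):
        return 0.5 * np.sum(G ** 2)

    def df(G):
        return G

    # reg1 weights the entropic term, reg2 weights f; values are arbitrary here
    G, log = gcg(a, b, M, reg1=1e-1, reg2=1e-1, f=f, df=df, log=True)
    print(G.shape, log['loss'][-1])   # (20, 30) and the final objective value

Unlike ot.optim.cg, which linearizes the whole regularized objective and solves an exact linear OT subproblem with emd, gcg keeps the entropic term inside the subproblem and solves it with sinkhorn_stabilized, so only the generic smooth term f is linearized at each iteration; this is the generalized conditional gradient scheme analyzed in [7].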