diff options
Diffstat (limited to 'ot')
-rw-r--r-- | ot/da.py | 116 | ||||
-rw-r--r-- | ot/optim.py | 13 |
2 files changed, 124 insertions, 5 deletions
@@ -8,6 +8,7 @@ from .bregman import sinkhorn from .lp import emd from .utils import unif,dist,kernel from .optim import cg +from .optim import gcg def indices(a, func): @@ -122,6 +123,100 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter return transp +def sinkhorn_l1l2_gl(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerItermax = 200,stopInnerThr=1e-9,verbose=False,log=False): + """ + Solve the entropic regularization optimal transport problem with group lasso regularization + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)+ \eta \Omega_g(\gamma) + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - M is the (ns,nt) metric cost matrix + - :math:`\Omega_e` is the entropic regularization term :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega_g` is the group lasso regulaization term :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^2` where :math:`\mathcal{I}_c` are the index of samples from class c in the source domain. + - a and b are source and target weights (sum to 1) + + The algorithm used for solving the problem is the generalised conditional gradient as proposed in [5]_ [7]_ + + + Parameters + ---------- + a : np.ndarray (ns,) + samples weights in the source domain + labels_a : np.ndarray (ns,) + labels of samples in the source domain + b : np.ndarray (nt,) + samples in the target domain + M : np.ndarray (ns,nt) + loss matrix + reg : float + Regularization term for entropic regularization >0 + eta : float, optional + Regularization term for group lasso regularization >0 + numItermax : int, optional + Max number of iterations + numInnerItermax : int, optional + Max number of iterations (inner sinkhorn solver) + stopInnerThr : float, optional + Stop threshold on error (inner sinkhorn solver) (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567. + + See Also + -------- + ot.optim.gcg : Generalized conditional gradient for OT problems + + """ + lstlab=np.unique(labels_a) + + def f(G): + res=0 + for i in range(G.shape[1]): + for lab in lstlab: + temp=G[labels_a==lab,i] + res+=np.linalg.norm(temp) + return res + + def df(G): + W=np.zeros(G.shape) + for i in range(G.shape[1]): + for lab in lstlab: + temp=G[labels_a==lab,i] + n=np.linalg.norm(temp) + if n: + W[labels_a==lab,i]=temp/n + return W + + + return gcg(a,b,M,reg,eta,f,df,G0=None,numItermax = numItermax,numInnerItermax=numInnerItermax, stopThr=stopInnerThr,verbose=verbose,log=log) + + + def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 10,stopInnerThr=1e-6,stopThr=1e-5,log=False,**kwargs): """Joint OT and linear mapping estimation as proposed in [8] @@ -632,6 +727,27 @@ class OTDA_lpl1(OTDA): self.M=dist(xs,xt,metric=self.metric) self.G=sinkhorn_lpl1_mm(ws,ys,wt,self.M,reg,eta,**kwargs) self.computed=True + +class OTDA_l1l2(OTDA): + """Class for domain adaptation with optimal transport with entropic and group lasso regularization""" + + + def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs): + """ Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_gl for fit parameters""" + self.xs=xs + self.xt=xt + + if wt is None: + wt=unif(xt.shape[0]) + if ws is None: + ws=unif(xs.shape[0]) + + self.ws=ws + self.wt=wt + + self.M=dist(xs,xt,metric=self.metric) + self.G=sinkhorn_l1l2_gl(ws,ys,wt,self.M,reg,eta,**kwargs) + self.computed=True class OTDA_mapping_linear(OTDA): """Class for optimal transport with joint linear mapping estimation as in [8]""" diff --git a/ot/optim.py b/ot/optim.py index 598e23f..d807824 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -7,6 +7,7 @@ import numpy as np from scipy.optimize.linesearch import scalar_search_armijo from .lp import emd from .bregman import sinkhorn_stabilized +from .bregman import sinkhorn # The corresponding scipy function does not work for matrices def line_search_armijo(f,xk,pk,gfk,old_fval,args=(),c1=1e-4,alpha0=0.99): @@ -195,7 +196,7 @@ def cg(a,b,M,reg,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=Fa else: return G -def gcg(a,b,M,reg1,reg2,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=False): +def gcg(a,b,M,reg1,reg2,f,df,G0=None,numItermax = 10,numInnerItermax = 200,stopThr=1e-9,verbose=False,log=False): """ Solve the general regularized OT problem with the generalized conditional gradient @@ -235,6 +236,8 @@ def gcg(a,b,M,reg1,reg2,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False initial guess (default is indep joint density) numItermax : int, optional Max number of iterations + numInnerItermax : int, optional + Max number of iterations of Sinkhorn stopThr : float, optional Stop threshol on error (>0) verbose : bool, optional @@ -293,16 +296,16 @@ def gcg(a,b,M,reg1,reg2,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False # problem linearization Mi=M+reg2*df(G) - # set M positive - Mi+=Mi.min() # solve linear program with Sinkhorn - Gc = sinkhorn_stabilized(a,b, Mi, reg1) + #Gc = sinkhorn_stabilized(a,b, Mi, reg1, numItermax = numInnerItermax) + Gc = sinkhorn(a,b, Mi, reg1, numItermax = numInnerItermax) deltaG=Gc-G # line search - alpha,fc,f_val = line_search_armijo(cost,G,deltaG,Mi,f_val) + dcost=Mi+reg1*np.sum(deltaG*(1+np.log(G))) #?? + alpha,fc,f_val = line_search_armijo(cost,G,deltaG,dcost,f_val) G=G+alpha*deltaG |