From 981351165dbab740145d109b00782f0c41f2244b Mon Sep 17 00:00:00 2001
From: Rémi Flamary
Date: Wed, 2 Nov 2016 17:13:43 +0100
Subject: add mapping estimation (still debugging)

---
 ot/da.py    | 228 +++++++++++++++++++++++++++++++++++++++++-------------------
 ot/optim.py | 106 ++++++++++++++--------------
 2 files changed, 208 insertions(+), 126 deletions(-)

diff --git a/ot/da.py b/ot/da.py
index 3447437..7cfbca1 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -7,6 +7,7 @@ import numpy as np
 from .bregman import sinkhorn
 from .lp import emd
 from .utils import unif,dist
+from .optim import cg


 def indices(a, func):
@@ -15,81 +16,81 @@ def indices(a, func):

def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerItermax = 200,stopInnerThr=1e-9,verbose=False,log=False):
    """
    Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization

    The function solves the following optimization problem:

    .. math::
        \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)+ \eta \Omega_g(\gamma)

        s.t. \gamma 1 = a

             \gamma^T 1= b

             \gamma\geq 0
    where :

    - M is the (ns,nt) metric cost matrix
    - :math:`\Omega_e` is the entropic regularization term :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\Omega_g` is the group lasso regularization term :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1` where :math:`\mathcal{I}_c` are the indices of samples from class c in the source domain.
    - a and b are source and target weights (sum to 1)

    The algorithm used for solving the problem is the generalised conditional gradient as proposed in [5]_ [7]_


    Parameters
    ----------
    a : np.ndarray (ns,)
        samples weights in the source domain
    labels_a : np.ndarray (ns,)
        labels of samples in the source domain
    b : np.ndarray (nt,)
        samples in the target domain
    M : np.ndarray (ns,nt)
        loss matrix
    reg : float
        Regularization term for entropic regularization >0
    eta : float, optional
        Regularization term for group lasso regularization >0
    numItermax : int, optional
        Max number of iterations
    numInnerItermax : int, optional
        Max number of iterations (inner sinkhorn solver)
    stopInnerThr : float, optional
        Stop threshold on error (inner sinkhorn solver) (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True


    Returns
    -------
    gamma : (ns x nt) ndarray
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary returned only if log==True in parameters


    References
    ----------

    .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence, vol. PP, no. 99, pp. 1-1
    .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.bregman.sinkhorn : Entropic regularized OT
    ot.optim.cg : General regularized OT

    """
    p=0.5
    epsilon = 1e-3

    # init data
    Nini = len(a)
    Nfin = len(b)

    indices_labels = []
    idx_begin = np.min(labels_a)
    for c in range(idx_begin,np.max(labels_a)+1):
@@ -117,14 +118,96 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
            # do it only for unlabeled data
            if idx_begin==-1:
                W[indices_labels[0],t]=np.min(all_maj)

    return transp


+def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 20,stopInnerThr=1e-9,stopThr=1e-6,log=False,**kwargs):
+    """Joint OT and mapping estimation (uniform weights)
+    """
+
+    ns,nt,d=xs.shape[0],xt.shape[0],xt.shape[1]
+
+    if bias:
+        xs1=np.hstack((xs,np.ones((ns,1))))
+        I=eta*np.eye(d+1)
+        I[-1]=0
+        I0=I[:,:-1]
+        sel=lambda x : x[:-1,:]
+    else:
+        xs1=xs
+        I=eta*np.eye(d)
+        I0=I
+        sel=lambda x : x
+
+    if log:
+        log={'err':[]}
+
+    a,b=unif(ns),unif(nt)
+    M=dist(xs,xt)
+    G=emd(a,b,M)
+
+    vloss=[]
+
+    def loss(L,G):
+        return np.sum((xs1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.sum(sel(L-I0)**2)
+
+    def solve_L(G):
+        """ solve problem with fixed G"""
+        xst=ns*G.dot(xt)
+        return np.linalg.solve(xs1.T.dot(xs1)+I,xs1.T.dot(xst)+I0)
+
+    def solve_G(L,G0):
+        xsi=xs1.dot(L)
+        def f(G):
+            return np.sum((xsi-ns*G.dot(xt))**2)
+        def df(G):
+            return -2*ns*(xsi-ns*G.dot(xt)).dot(xt.T)
+        G=cg(a,b,M,1.0/mu,f,df,G0=G0,numItermax=numInnerItermax,stopThr=stopInnerThr)
+        return G
+
+
+    L=solve_L(G)
+
+    vloss.append(loss(L,G))
+
+    if verbose:
+        print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
+        print('{:5d}|{:8e}|{:8e}'.format(0,vloss[-1],0))
+
+
+    # regul matrix
+    loop=1
+    it=0
+
+    while loop:
+
+        it+=1
+
+        # update G
+        G=solve_G(L,G)
+
+        #update L
+        L=solve_L(G)
+
+        vloss.append(loss(L,G))
+
+        if abs(vloss[-1]-vloss[-2])<stopThr:
+            loop=0
-        if direction>0: # >0 then source to target
+        if direction>0: # >0 then source to target
            G=self.G
            w=self.ws.reshape((self.xs.shape[0],1))
            x=self.xt
@@ -175,81 +258,80 @@ class OTDA(object):
            G=self.G.T
            w=self.wt.reshape((self.xt.shape[0],1))
            x=self.xs

        if self.computed:
            if self.metric=='sqeuclidean':
                return np.dot(G/w,x) # weighted mean
            else:
                print("Warning, metric not handled yet, using weighted average")
                return np.dot(G/w,x) # weighted mean
                return None
        else:
            print("Warning, model not fitted yet, returning None")
            return None


    def predict(self,x,direction=1):
        """ Out of sample mapping using the formulation from Ferradans

        It basically finds the source sample nearest to the new sample and
        applies the difference to the displaced source sample.
- + """ - if direction>0: # >0 then source to target + if direction>0: # >0 then source to target xf=self.xt x0=self.xs else: - xf=self.xs + xf=self.xs x0=self.xt - + D0=dist(x,x0) # dist netween new samples an source idx=np.argmin(D0,1) # closest one xf=self.interp(direction)# interp the source samples return xf[idx,:]+x-x0[idx,:] # aply the delta to the interpolation - - + + class OTDA_sinkhorn(OTDA): """Class for domain adaptation with optimal transport with entropic regularization""" def fit(self,xs,xt,reg=1,ws=None,wt=None,**kwargs): - """ Fit domain adaptation between samples is xs and xt (with optional + """ Fit domain adaptation between samples is xs and xt (with optional weights)""" self.xs=xs self.xt=xt - + if wt is None: wt=unif(xt.shape[0]) if ws is None: ws=unif(xs.shape[0]) - + self.ws=ws self.wt=wt - + self.M=dist(xs,xt,metric=self.metric) self.G=sinkhorn(ws,wt,self.M,reg,**kwargs) - self.computed=True - - + self.computed=True + + class OTDA_lpl1(OTDA): """Class for domain adaptation with optimal transport with entropic an group regularization""" - - + + def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs): - """ Fit domain adaptation between samples is xs and xt (with optional + """ Fit domain adaptation between samples is xs and xt (with optional weights)""" self.xs=xs self.xt=xt - + if wt is None: wt=unif(xt.shape[0]) if ws is None: ws=unif(xs.shape[0]) - + self.ws=ws self.wt=wt - + self.M=dist(xs,xt,metric=self.metric) self.G=sinkhorn_lpl1_mm(ws,ys,wt,self.M,reg,eta,**kwargs) - self.computed=True - - \ No newline at end of file + self.computed=True + diff --git a/ot/optim.py b/ot/optim.py index 1afbea3..dcefd24 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -12,21 +12,21 @@ from .lp import emd def line_search_armijo(f,xk,pk,gfk,old_fval,args=(),c1=1e-4,alpha0=0.99): """ Armijo linesearch function that works with matrices - - find an approximate minimum of f(xk+alpha*pk) that satifies the - armijo conditions. - + + find an approximate minimum of f(xk+alpha*pk) that satifies the + armijo conditions. + Parameters ---------- f : function loss function - xk : np.ndarray + xk : np.ndarray initial position - pk : np.ndarray + pk : np.ndarray descent direction gfk : np.ndarray - gradient of f at xk + gradient of f at xk old_fval: float loss value at xk args : tuple, optional @@ -35,7 +35,7 @@ def line_search_armijo(f,xk,pk,gfk,old_fval,args=(),c1=1e-4,alpha0=0.99): c1 const in armijo rule (>0) alpha0 : float, optional initial step (>0) - + Returns ------- alpha : float @@ -44,49 +44,49 @@ def line_search_armijo(f,xk,pk,gfk,old_fval,args=(),c1=1e-4,alpha0=0.99): nb of function call fa : float loss value at step alpha - + """ xk = np.atleast_1d(xk) fc = [0] - + def phi(alpha1): fc[0] += 1 return f(xk + alpha1*pk, *args) - + if old_fval is None: phi0 = phi(0.) else: phi0 = old_fval - + derphi0 = np.sum(pk*gfk) # Quickfix for matrices alpha,phi1 = scalar_search_armijo(phi,phi0,derphi0,c1=c1,alpha0=alpha0) - + return alpha,fc[0],phi1 def cg(a,b,M,reg,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=False): """ Solve the general regularized OT problem with conditional gradient - + The function solves the following optimization problem: - + .. math:: \gamma = arg\min_\gamma <\gamma,M>_F + reg*f(\gamma) - + s.t. 

             \gamma^T 1= b

             \gamma\geq 0
    where :

    - M is the (ns,nt) metric cost matrix
    - :math:`f` is the regularization term (and df is its gradient)
    - a and b are source and target weights (sum to 1)

    The algorithm used for solving the problem is conditional gradient as discussed in [1]_


    Parameters
    ----------
    a : np.ndarray (ns,)
@@ -94,7 +94,7 @@ def cg(a,b,M,reg,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=Fa
    b : np.ndarray (nt,)
        samples in the target domain
    M : np.ndarray (ns,nt)
        loss matrix
    reg : float
        Regularization term >0
    G0 : np.ndarray (ns,nt), optional
@@ -107,87 +107,87 @@ def cg(a,b,M,reg,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=Fa
        Print information along iterations
    log : bool, optional
        record log if True

    Returns
    -------
    gamma : (ns x nt) ndarray
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary returned only if log==True in parameters


    References
    ----------

    .. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.

    See Also
    --------
    ot.lp.emd : Unregularized optimal transport
    ot.bregman.sinkhorn : Entropic regularized optimal transport

    """

    loop=1

    if log:
        log={'loss':[]}

    if G0 is None:
        G=np.outer(a,b)
    else:
        G=G0

    def cost(G):
        return np.sum(M*G)+reg*f(G)

    f_val=cost(G)
    if log:
        log['loss'].append(f_val)

    it=0

    if verbose:
        print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
        print('{:5d}|{:8e}|{:8e}'.format(it,f_val,0))

    while loop:

        it+=1
        old_fval=f_val


        # problem linearization
        Mi=M+reg*df(G)

        # solve linear program
        Gc=emd(a,b,Mi)

        deltaG=Gc-G

        # line search
        alpha,fc,f_val = line_search_armijo(cost,G,deltaG,Mi,f_val)

        G=G+alpha*deltaG

        # test convergence
        if it>=numItermax:
            loop=0

        delta_fval=(f_val-old_fval)/abs(f_val)
        if abs(delta_fval)<stopThr:
            loop=0
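
A minimal usage sketch of the generic cg() solver touched by this patch (not part of the commit itself). It assumes the package layout above, i.e. ot.optim.cg together with ot.utils.unif and ot.utils.dist from the same tree; the squared-Frobenius regularizer f/df below is an illustrative choice, not something defined in this patch.

    import numpy as np
    from ot.utils import unif, dist
    from ot.optim import cg

    np.random.seed(0)
    xs = np.random.randn(20, 2)        # source samples
    xt = np.random.randn(30, 2) + 1.0  # shifted target samples

    a, b = unif(xs.shape[0]), unif(xt.shape[0])  # uniform marginals
    M = dist(xs, xt)                             # squared Euclidean cost matrix

    # example regularizer: 0.5*||G||_F^2 and its gradient
    def f(G):
        return 0.5 * np.sum(G ** 2)

    def df(G):
        return G

    G = cg(a, b, M, reg=1e-1, f=f, df=df, verbose=True)
    print(G.shape)            # (20, 30) coupling matrix
    print(G.sum(axis=1)[:3])  # rows should be close to the entries of a

This is the same solver that the new joint_OT_mapping_linear() calls internally (via solve_G), with the mapping data-fitting term as f and 1/mu as the regularization weight.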