diff options
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 228 |
1 files changed, 155 insertions, 73 deletions
@@ -7,6 +7,7 @@ import numpy as np from .bregman import sinkhorn from .lp import emd from .utils import unif,dist +from .optim import cg def indices(a, func): @@ -15,81 +16,81 @@ def indices(a, func): def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerItermax = 200,stopInnerThr=1e-9,verbose=False,log=False): """ Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization - + The function solves the following optimization problem: - + .. math:: \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)+ \eta \Omega_g(\gamma) - + s.t. \gamma 1 = a - - \gamma^T 1= b - + + \gamma^T 1= b + \gamma\geq 0 where : - + - M is the (ns,nt) metric cost matrix - :math:`\Omega_e` is the entropic regularization term :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\Omega_g` is the group lasso regulaization term :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1` where :math:`\mathcal{I}_c` are the index of samples from class c in the source domain. 
- a and b are source and target weights (sum to 1) - + The algorithm used for solving the problem is the generalised conditional gradient as proposed in [5]_ [7]_ - - + + Parameters ---------- a : np.ndarray (ns,) samples weights in the source domain labels_a : np.ndarray (ns,) - labels of samples in the source domain + labels of samples in the source domain b : np.ndarray (nt,) samples in the target domain M : np.ndarray (ns,nt) - loss matrix + loss matrix reg: float Regularization term for entropic regularization >0 eta: float, optional - Regularization term for group lasso regularization >0 + Regularization term for group lasso regularization >0 numItermax: int, optional Max number of iterations numInnerItermax: int, optional Max number of iterations (inner sinkhorn solver) stopInnerThr: float, optional - Stop threshold on error (inner sinkhorn solver) (>0) + Stop threshold on error (inner sinkhorn solver) (>0) verbose : bool, optional Print information along iterations log : bool, optional - record log if True - - + record log if True + + Returns ------- gamma: (ns x nt) ndarray Optimal transportation matrix for the given parameters log: dict - log dictionary return only if log==True in parameters - - + log dictionary return only if log==True in parameters + + References ---------- - + .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567. 
def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
                            verbose2=False, numItermax=100, numInnerItermax=20,
                            stopInnerThr=1e-9, stopThr=1e-6, log=False,
                            **kwargs):
    """Jointly estimate an OT coupling and a linear mapping (uniform weights).

    Alternates between two updates until the loss stops decreasing:

    - ``G`` (coupling) is updated with ``cg`` for a fixed mapping ``L``;
    - ``L`` (linear mapping) is updated in closed form by a regularised
      least-squares solve for a fixed coupling ``G``.

    The monitored loss is::

        ||xs1.L - ns.G.xt||^2 + mu.<G, M> + eta.||sel(L - I)||^2

    where ``xs1`` is ``xs`` (with a column of ones appended when ``bias`` is
    True), ``M = dist(xs, xt)`` and ``sel`` drops the bias row from the
    penalty so that the bias term is not regularised.

    Parameters
    ----------
    xs : np.ndarray (ns,d)
        samples in the source domain
    xt : np.ndarray (nt,d)
        samples in the target domain
    mu : float, optional
        weight of the transport cost term <G, M>; ``1.0/mu`` is passed as the
        regularisation weight to ``cg``
    eta : float, optional
        weight of the penalty pulling L toward the identity
    bias : bool, optional
        estimate an additional (unpenalised) bias row in L
    verbose : bool, optional
        print the loss along iterations
    verbose2 : bool, optional
        currently unused
    numItermax : int, optional
        max number of outer (alternating) iterations
    numInnerItermax : int, optional
        max number of iterations of the inner ``cg`` solver
    stopInnerThr : float, optional
        stop threshold passed to the inner ``cg`` solver
    stopThr : float, optional
        stop threshold on the absolute change of the loss between two
        consecutive outer iterations
    log : bool, optional
        accepted for signature compatibility; currently has no effect and
        no log is returned
    **kwargs
        currently unused

    Returns
    -------
    G : (ns x nt) ndarray
        coupling after the last outer iteration
    L : (d x d) ndarray, or (d+1 x d) ndarray if bias is True
        linear mapping after the last outer iteration
    """

    ns, nt, d = xs.shape[0], xt.shape[0], xt.shape[1]

    if bias:
        xs1 = np.hstack((xs, np.ones((ns, 1))))
        Id = np.eye(d + 1)
        Id[-1] = 0  # last row is the bias: excluded from the penalty
        I0 = Id[:, :-1]

        def sel(x):
            return x[:-1, :]
    else:
        xs1 = xs
        Id = np.eye(d)
        I0 = Id

        def sel(x):
            return x

    # The Gram matrix does not depend on G: compute it once, not per update.
    xstxs = xs1.T.dot(xs1)

    a, b = unif(ns), unif(nt)
    M = dist(xs, xt)
    G = emd(a, b, M)  # initial coupling

    vloss = []

    def loss(L, G):
        """Full objective for mapping L and coupling G."""
        # I0 is deliberately NOT scaled by eta here: the penalty is
        # eta*||L - I||^2, which is the objective solve_L actually minimises.
        return (np.sum((xs1.dot(L) - ns * G.dot(xt))**2)
                + mu * np.sum(G * M)
                + eta * np.sum(sel(L - I0)**2))

    def solve_L(G):
        """Closed-form update of L for a fixed coupling G."""
        xst = ns * G.dot(xt)
        # Normal equations of ||xs1.L - xst||^2 + eta*||sel(L - I)||^2
        return np.linalg.solve(xstxs + eta * Id, xs1.T.dot(xst) + eta * I0)

    def solve_G(L, G0):
        """Update of G for a fixed mapping L, warm-started at G0."""
        xsi = xs1.dot(L)

        def f(G):
            return np.sum((xsi - ns * G.dot(xt))**2)

        def df(G):
            return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T)

        return cg(a, b, M, 1.0 / mu, f, df, G0=G0,
                  numItermax=numInnerItermax, stopThr=stopInnerThr)

    L = solve_L(G)

    vloss.append(loss(L, G))

    if verbose:
        print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
        print('{:5d}|{:8e}|{:8e}'.format(0,vloss[-1],0))

    # Bounded alternation: numItermax caps the number of outer iterations so
    # the loop always terminates even if the loss never stabilises.
    for it in range(1, numItermax + 1):

        # update G
        G = solve_G(L, G)

        # update L
        L = solve_L(G)

        vloss.append(loss(L, G))
        previous = vloss[-2]
        delta = abs(vloss[-1] - previous)

        if verbose:
            if it % 20 == 0:
                print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
            # Relative change is displayed; fall back to the absolute change
            # when the previous loss is exactly zero to avoid dividing by 0.
            shown = delta / abs(previous) if previous != 0 else delta
            print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],shown))

        # Stopping test is on the absolute change of the loss.
        if delta < stopThr:
            break

    return G, L
class OTDA_sinkhorn(OTDA):
    """Domain adaptation with entropic regularized optimal transport"""

    def fit(self, xs, xt, reg=1, ws=None, wt=None, **kwargs):
        """Fit the coupling between samples xs and xt (optional weights).

        Stores the samples, the weights, the cost matrix ``M`` (computed by
        ``dist`` with ``self.metric``) and the coupling ``G`` returned by
        ``sinkhorn`` called with regularization ``reg``; extra keyword
        arguments are forwarded to ``sinkhorn``.
        """
        self.xs = xs
        self.xt = xt

        # Weights not supplied by the caller are built with unif()
        wt = unif(xt.shape[0]) if wt is None else wt
        ws = unif(xs.shape[0]) if ws is None else ws
        self.ws, self.wt = ws, wt

        cost = dist(xs, xt, metric=self.metric)
        self.M = cost
        self.G = sinkhorn(ws, wt, cost, reg, **kwargs)
        self.computed = True


class OTDA_lpl1(OTDA):
    """Domain adaptation with optimal transport using entropic and group
    regularization"""

    def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs):
        """Fit the coupling between samples xs and xt (optional weights).

        ``ys`` is passed to ``sinkhorn_lpl1_mm`` as the source labels, with
        ``reg`` and ``eta`` as its two regularization weights; extra keyword
        arguments are forwarded to it. Stores the samples, the weights, the
        cost matrix ``M`` (computed by ``dist`` with ``self.metric``) and the
        resulting coupling ``G``.
        """
        self.xs = xs
        self.xt = xt

        # Weights not supplied by the caller are built with unif()
        wt = unif(xt.shape[0]) if wt is None else wt
        ws = unif(xs.shape[0]) if ws is None else ws
        self.ws, self.wt = ws, wt

        cost = dist(xs, xt, metric=self.metric)
        self.M = cost
        self.G = sinkhorn_lpl1_mm(ws, ys, wt, cost, reg, eta, **kwargs)
        self.computed = True