diff options
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 48 |
1 files changed, 44 insertions, 4 deletions
@@ -247,7 +247,7 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b def loss(L,G): """Compute full loss""" - return np.sum((K1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.sum(sel(L)**2) + return np.sum((K1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.trace(L.T.dot(K0).dot(L)) def solve_L_nobias(G): """ solve L problem with fixed G (least square)""" @@ -450,11 +450,11 @@ class OTDA_lpl1(OTDA): self.G=sinkhorn_lpl1_mm(ws,ys,wt,self.M,reg,eta,**kwargs) self.computed=True -class OTDA_mapping(OTDA): +class OTDA_mapping_linear(OTDA): """Class for optimal transport with joint linear mapping estimation""" - def __init__(self,metric='sqeuclidean'): + def __init__(self): """ Class initialization""" @@ -463,8 +463,8 @@ class OTDA_mapping(OTDA): self.G=0 self.L=0 self.bias=False - self.metric=metric self.computed=False + self.metric='sqeuclidean' def fit(self,xs,xt,mu=1,eta=1,bias=False,**kwargs): """ Fit domain adaptation between samples is xs and xt (with optional @@ -473,6 +473,7 @@ class OTDA_mapping(OTDA): self.xt=xt self.bias=bias + self.ws=unif(xs.shape[0]) self.wt=unif(xt.shape[0]) @@ -498,3 +499,42 @@ class OTDA_mapping(OTDA): print("Warning, model not fitted yet, returning None") return None +class OTDA_mapping_kernel(OTDA_mapping_linear): + """Class for optimal transport with joint linear mapping estimation""" + + + + def fit(self,xs,xt,mu=1,eta=1,bias=False,kerneltype='gaussian',sigma=1,**kwargs): + """ Fit domain adaptation between samples is xs and xt (with optional + weights)""" + self.xs=xs + self.xt=xt + self.bias=bias + + self.ws=unif(xs.shape[0]) + self.wt=unif(xt.shape[0]) + self.kernel=kerneltype + self.sigma=sigma + self.kwargs=kwargs + + + self.G,self.L=joint_OT_mapping_kernel(xs,xt,mu=mu,eta=eta,bias=bias,**kwargs) + self.computed=True + + + def predict(self,x): + """ Out of sample mapping using the formulation from Ferradans + + It basically find the source sample the nearset to the nex sample and + apply the difference to the displaced source sample. + + """ + + if self.computed: + K=kernel(x,self.xs,method=self.kernel,sigma=self.sigma,**self.kwargs) + if self.bias: + K=np.hstack((K,np.ones((x.shape[0],1)))) + return K.dot(self.L) + else: + print("Warning, model not fitted yet, returning None") + return None
\ No newline at end of file |