diff options
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 39 |
1 files changed, 32 insertions, 7 deletions
@@ -188,21 +188,46 @@ class OTDA(): return None - def predict(x,direction=1): + def predict(self,x,direction=1): """ Out of sample mapping using the formulation from Ferradans + It basically find the source sample the nearset to the nex sample and + apply the difference to the displaced source sample. + """ if direction>0: # >0 then source to target - G=self.G - w=self.ws.reshape((self.xs.shape[0],1)) - x=self.xt + xf=self.xt + x0=self.xs else: - G=self.G.T - w=self.wt.reshape((self.xt.shape[0],1)) - x=self.xs + xf=self.xs + x0=self.xt + + D0=dist(x,x0) # dist netween new samples an source + idx=np.argmin(D0,1) # closest one + xf=self.interp(direction)# interp the source samples + return xf[idx,:]+x-x0[idx,:] # aply the delta to the interpolation + +class OTDA_sinkhorn(OTDA): + + def fit(self,xs,xt,ws=None,wt=None,reg=1,**kwargs): + """ Fit domain adaptation between samples is xs and xt (with optional + weights)""" + self.xs=xs + self.xt=xt + if wt is None: + wt=unif(xt.shape[0]) + if ws is None: + ws=unif(xs.shape[0]) + + self.ws=ws + self.wt=wt + + self.M=dist(xs,xt,metric=self.metric) + self.G=sinkhorn(ws,wt,self.M,reg,**kwargs) + self.computed=True |