diff options
-rw-r--r-- | ot/da.py | 25 |
1 files changed, 23 insertions, 2 deletions
@@ -210,8 +210,8 @@ class OTDA(): class OTDA_sinkhorn(OTDA): - - def fit(self,xs,xt,ws=None,wt=None,reg=1,**kwargs): + """Class for domain adaptation with optimal transport with entropic regularization""" + def fit(self,xs,xt,reg=1,ws=None,wt=None,**kwargs): """ Fit domain adaptation between samples is xs and xt (with optional weights)""" self.xs=xs @@ -230,5 +230,26 @@ class OTDA_sinkhorn(OTDA): self.computed=True +class OTDA_lpl1(OTDA): + """Class for domain adaptation with optimal transport with entropic an group regularization""" + + def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs): + """ Fit domain adaptation between samples is xs and xt (with optional + weights)""" + self.xs=xs + self.xt=xt + + if wt is None: + wt=unif(xt.shape[0]) + if ws is None: + ws=unif(xs.shape[0]) + + self.ws=ws + self.wt=wt + + self.M=dist(xs,xt,metric=self.metric) + self.G=sinkhorn_lpl1_mm(ws,ys,wt,self.M,reg,eta,**kwargs) + self.computed=True +
\ No newline at end of file |