diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-31 14:36:34 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-31 14:36:34 +0100 |
commit | fdceacbac0d4b0380c90ca2c942e5abd0f69df64 (patch) | |
tree | ffe8fe373c7f7ae2f65172e820230c57762da84f | |
parent | 104627b2f69eb22d3f9010955e6765ac2b179faa (diff) |
add classes for entropic and group lasso regularization
-rw-r--r-- | ot/da.py | 25 |
1 files changed, 23 insertions, 2 deletions
@@ -210,8 +210,8 @@ class OTDA(): class OTDA_sinkhorn(OTDA): - - def fit(self,xs,xt,ws=None,wt=None,reg=1,**kwargs): + """Class for domain adaptation with optimal transport with entropic regularization""" + def fit(self,xs,xt,reg=1,ws=None,wt=None,**kwargs): """ Fit domain adaptation between samples is xs and xt (with optional weights)""" self.xs=xs @@ -230,5 +230,26 @@ class OTDA_sinkhorn(OTDA): self.computed=True +class OTDA_lpl1(OTDA): + """Class for domain adaptation with optimal transport with entropic an group regularization""" + + def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs): + """ Fit domain adaptation between samples is xs and xt (with optional + weights)""" + self.xs=xs + self.xt=xt + + if wt is None: + wt=unif(xt.shape[0]) + if ws is None: + ws=unif(xs.shape[0]) + + self.ws=ws + self.wt=wt + + self.M=dist(xs,xt,metric=self.metric) + self.G=sinkhorn_lpl1_mm(ws,ys,wt,self.M,reg,eta,**kwargs) + self.computed=True +
\ No newline at end of file |