diff options
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 26 |
1 files changed, 22 insertions, 4 deletions
@@ -606,7 +606,7 @@ class OTDA(object): self.computed=False - def fit(self,xs,xt,ws=None,wt=None): + def fit(self,xs,xt,ws=None,wt=None,norm=None): """ Fit domain adaptation between samples is xs and xt (with optional weights)""" self.xs=xs self.xt=xt @@ -620,6 +620,7 @@ class OTDA(object): self.wt=wt self.M=dist(xs,xt,metric=self.metric) + self.normalize() self.G=emd(ws,wt,self.M) self.computed=True @@ -684,11 +685,25 @@ class OTDA(object): xf=self.interp(direction)# interp the source samples return xf[idx,:]+x-x0[idx,:] # aply the delta to the interpolation + def normalizeM(self, norm): + """ + It may help to normalize the cost matrix self.M if there are numerical + errors during the sinkhorn based algorithms. + """ + if norm == "median": + self.M /= float(np.median(self.M)) + elif norm == "max": + self.M /= float(np.max(self.M)) + elif norm == "log": + self.M = np.log(1 + self.M) + elif norm == "loglog": + self.M = np.log(1 + np.log(1 + self.M)) + class OTDA_sinkhorn(OTDA): """Class for domain adaptation with optimal transport with entropic regularization""" - def fit(self,xs,xt,reg=1,ws=None,wt=None,**kwargs): + def fit(self,xs,xt,reg=1,ws=None,wt=None,norm=None,**kwargs): """ Fit regularized domain adaptation between samples is xs and xt (with optional weights)""" self.xs=xs self.xt=xt @@ -702,6 +717,7 @@ class OTDA_sinkhorn(OTDA): self.wt=wt self.M=dist(xs,xt,metric=self.metric) + self.normalizeM(norm) self.G=sinkhorn(ws,wt,self.M,reg,**kwargs) self.computed=True @@ -710,7 +726,7 @@ class OTDA_lpl1(OTDA): """Class for domain adaptation with optimal transport with entropic and group regularization""" - def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs): + def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,norm=None,**kwargs): """ Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_mm for fit parameters""" self.xs=xs self.xt=xt @@ -724,6 +740,7 @@ class OTDA_lpl1(OTDA): self.wt=wt self.M=dist(xs,xt,metric=self.metric) + self.normalizeM(norm) self.G=sinkhorn_lpl1_mm(ws,ys,wt,self.M,reg,eta,**kwargs) self.computed=True @@ -731,7 +748,7 @@ class OTDA_l1l2(OTDA): """Class for domain adaptation with optimal transport with entropic and group lasso regularization""" - def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs): + def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,norm=None,**kwargs): """ Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_gl for fit parameters""" self.xs=xs self.xt=xt @@ -745,6 +762,7 @@ class OTDA_l1l2(OTDA): self.wt=wt self.M=dist(xs,xt,metric=self.metric) + self.normalizeM(norm) self.G=sinkhorn_l1l2_gl(ws,ys,wt,self.M,reg,eta,**kwargs) self.computed=True |