diff options
-rw-r--r-- | ot/da.py | 21 |
1 files changed, 10 insertions, 11 deletions
@@ -650,15 +650,16 @@ class OTDA(object): """ - def __init__(self, metric='sqeuclidean'): + def __init__(self, metric='sqeuclidean', norm=None): """ Class initialization""" self.xs = 0 self.xt = 0 self.G = 0 self.metric = metric + self.norm = norm self.computed = False - def fit(self, xs, xt, ws=None, wt=None, norm=None, max_iter=100000): + def fit(self, xs, xt, ws=None, wt=None, max_iter=100000): """Fit domain adaptation between samples is xs and xt (with optional weights)""" self.xs = xs @@ -673,7 +674,7 @@ class OTDA(object): self.wt = wt self.M = dist(xs, xt, metric=self.metric) - self.M = cost_normalization(self.M, norm) + self.M = cost_normalization(self.M, self.norm) self.G = emd(ws, wt, self.M, max_iter) self.computed = True @@ -752,7 +753,7 @@ class OTDA_sinkhorn(OTDA): """ - def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs): + def fit(self, xs, xt, reg=1, ws=None, wt=None, **kwargs): """Fit regularized domain adaptation between samples is xs and xt (with optional weights)""" self.xs = xs @@ -767,7 +768,7 @@ class OTDA_sinkhorn(OTDA): self.wt = wt self.M = dist(xs, xt, metric=self.metric) - self.M = cost_normalization(self.M, norm) + self.M = cost_normalization(self.M, self.norm) self.G = sinkhorn(ws, wt, self.M, reg, **kwargs) self.computed = True @@ -779,8 +780,7 @@ class OTDA_lpl1(OTDA): """Class for domain adaptation with optimal transport with entropic and group regularization""" - def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None, - **kwargs): + def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs): """Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_mm for fit parameters""" @@ -796,7 +796,7 @@ class OTDA_lpl1(OTDA): self.wt = wt self.M = dist(xs, xt, metric=self.metric) - self.M = cost_normalization(self.M, norm) + self.M = cost_normalization(self.M, self.norm) self.G = sinkhorn_lpl1_mm(ws, ys, wt, self.M, reg, eta, **kwargs) self.computed = True @@ -808,8 +808,7 @@ class OTDA_l1l2(OTDA): """Class for domain adaptation with optimal transport with entropic and group lasso regularization""" - def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None, - **kwargs): + def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs): """Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_gl for fit parameters""" @@ -825,7 +824,7 @@ class OTDA_l1l2(OTDA): self.wt = wt self.M = dist(xs, xt, metric=self.metric) - self.M = cost_normalization(self.M, norm) + self.M = cost_normalization(self.M, self.norm) self.G = sinkhorn_l1l2_gl(ws, ys, wt, self.M, reg, eta, **kwargs) self.computed = True |