diff options
author | aje <leo_g_autheron@hotmail.fr> | 2017-08-30 11:25:03 +0200 |
---|---|---|
committer | aje <leo_g_autheron@hotmail.fr> | 2017-08-30 11:25:03 +0200 |
commit | fadaf2ab3c3844d281b22f8d5c3404c3c4cf7d97 (patch) | |
tree | b9a8659370286820563a1fd1a9ea09ed0a9003a3 /ot/da.py | |
parent | 0316d552fa6005aaf0f6231eb9ca20441d5a2532 (diff) |
Move norm out of fit to init for deprecated OTDA
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 21 |
1 files changed, 10 insertions, 11 deletions
@@ -650,15 +650,16 @@ class OTDA(object): """ - def __init__(self, metric='sqeuclidean'): + def __init__(self, metric='sqeuclidean', norm=None): """ Class initialization""" self.xs = 0 self.xt = 0 self.G = 0 self.metric = metric + self.norm = norm self.computed = False - def fit(self, xs, xt, ws=None, wt=None, norm=None, max_iter=100000): + def fit(self, xs, xt, ws=None, wt=None, max_iter=100000): """Fit domain adaptation between samples is xs and xt (with optional weights)""" self.xs = xs @@ -673,7 +674,7 @@ class OTDA(object): self.wt = wt self.M = dist(xs, xt, metric=self.metric) - self.M = cost_normalization(self.M, norm) + self.M = cost_normalization(self.M, self.norm) self.G = emd(ws, wt, self.M, max_iter) self.computed = True @@ -752,7 +753,7 @@ class OTDA_sinkhorn(OTDA): """ - def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs): + def fit(self, xs, xt, reg=1, ws=None, wt=None, **kwargs): """Fit regularized domain adaptation between samples is xs and xt (with optional weights)""" self.xs = xs @@ -767,7 +768,7 @@ class OTDA_sinkhorn(OTDA): self.wt = wt self.M = dist(xs, xt, metric=self.metric) - self.M = cost_normalization(self.M, norm) + self.M = cost_normalization(self.M, self.norm) self.G = sinkhorn(ws, wt, self.M, reg, **kwargs) self.computed = True @@ -779,8 +780,7 @@ class OTDA_lpl1(OTDA): """Class for domain adaptation with optimal transport with entropic and group regularization""" - def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None, - **kwargs): + def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs): """Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_mm for fit parameters""" @@ -796,7 +796,7 @@ class OTDA_lpl1(OTDA): self.wt = wt self.M = dist(xs, xt, metric=self.metric) - self.M = cost_normalization(self.M, norm) + self.M = cost_normalization(self.M, self.norm) self.G = sinkhorn_lpl1_mm(ws, ys, wt, self.M, reg, eta, **kwargs) self.computed = True @@ -808,8 +808,7 @@ class OTDA_l1l2(OTDA): """Class for domain adaptation with optimal transport with entropic and group lasso regularization""" - def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None, - **kwargs): + def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs): """Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_gl for fit parameters""" @@ -825,7 +824,7 @@ class OTDA_l1l2(OTDA): self.wt = wt self.M = dist(xs, xt, metric=self.metric) - self.M = cost_normalization(self.M, norm) + self.M = cost_normalization(self.M, self.norm) self.G = sinkhorn_l1l2_gl(ws, ys, wt, self.M, reg, eta, **kwargs) self.computed = True |