From 93dee553a3dd5d6e3c5a5d325bb6333e8eb24dee Mon Sep 17 00:00:00 2001 From: aje Date: Wed, 30 Aug 2017 11:25:03 +0200 Subject: Move norm out of fit to init for deprecated OTDA --- ot/da.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) (limited to 'ot') diff --git a/ot/da.py b/ot/da.py index 61a3ba0..564c7b7 100644 --- a/ot/da.py +++ b/ot/da.py @@ -650,15 +650,16 @@ class OTDA(object): """ - def __init__(self, metric='sqeuclidean'): + def __init__(self, metric='sqeuclidean', norm=None): """ Class initialization""" self.xs = 0 self.xt = 0 self.G = 0 self.metric = metric + self.norm = norm self.computed = False - def fit(self, xs, xt, ws=None, wt=None, norm=None, max_iter=100000): + def fit(self, xs, xt, ws=None, wt=None, max_iter=100000): """Fit domain adaptation between samples is xs and xt (with optional weights)""" self.xs = xs @@ -673,7 +674,7 @@ class OTDA(object): self.wt = wt self.M = dist(xs, xt, metric=self.metric) - self.M = cost_normalization(self.M, norm) + self.M = cost_normalization(self.M, self.norm) self.G = emd(ws, wt, self.M, max_iter) self.computed = True @@ -752,7 +753,7 @@ class OTDA_sinkhorn(OTDA): """ - def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs): + def fit(self, xs, xt, reg=1, ws=None, wt=None, **kwargs): """Fit regularized domain adaptation between samples is xs and xt (with optional weights)""" self.xs = xs @@ -767,7 +768,7 @@ class OTDA_sinkhorn(OTDA): self.wt = wt self.M = dist(xs, xt, metric=self.metric) - self.M = cost_normalization(self.M, norm) + self.M = cost_normalization(self.M, self.norm) self.G = sinkhorn(ws, wt, self.M, reg, **kwargs) self.computed = True @@ -779,8 +780,7 @@ class OTDA_lpl1(OTDA): """Class for domain adaptation with optimal transport with entropic and group regularization""" - def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None, - **kwargs): + def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs): """Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_mm for fit parameters""" @@ -796,7 +796,7 @@ class OTDA_lpl1(OTDA): self.wt = wt self.M = dist(xs, xt, metric=self.metric) - self.M = cost_normalization(self.M, norm) + self.M = cost_normalization(self.M, self.norm) self.G = sinkhorn_lpl1_mm(ws, ys, wt, self.M, reg, eta, **kwargs) self.computed = True @@ -808,8 +808,7 @@ class OTDA_l1l2(OTDA): """Class for domain adaptation with optimal transport with entropic and group lasso regularization""" - def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None, - **kwargs): + def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs): """Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_gl for fit parameters""" @@ -825,7 +824,7 @@ class OTDA_l1l2(OTDA): self.wt = wt self.M = dist(xs, xt, metric=self.metric) - self.M = cost_normalization(self.M, norm) + self.M = cost_normalization(self.M, self.norm) self.G = sinkhorn_l1l2_gl(ws, ys, wt, self.M, reg, eta, **kwargs) self.computed = True -- cgit v1.2.3