summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authoraje <leo_g_autheron@hotmail.fr>2017-08-30 11:25:03 +0200
committeraje <leo_g_autheron@hotmail.fr>2017-08-30 11:25:03 +0200
commitfadaf2ab3c3844d281b22f8d5c3404c3c4cf7d97 (patch)
treeb9a8659370286820563a1fd1a9ea09ed0a9003a3 /ot/da.py
parent0316d552fa6005aaf0f6231eb9ca20441d5a2532 (diff)
Move norm out of fit to init for deprecated OTDA
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py21
1 files changed, 10 insertions, 11 deletions
diff --git a/ot/da.py b/ot/da.py
index 61a3ba0..564c7b7 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -650,15 +650,16 @@ class OTDA(object):
"""
- def __init__(self, metric='sqeuclidean'):
+ def __init__(self, metric='sqeuclidean', norm=None):
""" Class initialization"""
self.xs = 0
self.xt = 0
self.G = 0
self.metric = metric
+ self.norm = norm
self.computed = False
- def fit(self, xs, xt, ws=None, wt=None, norm=None, max_iter=100000):
+ def fit(self, xs, xt, ws=None, wt=None, max_iter=100000):
"""Fit domain adaptation between samples is xs and xt
(with optional weights)"""
self.xs = xs
@@ -673,7 +674,7 @@ class OTDA(object):
self.wt = wt
self.M = dist(xs, xt, metric=self.metric)
- self.M = cost_normalization(self.M, norm)
+ self.M = cost_normalization(self.M, self.norm)
self.G = emd(ws, wt, self.M, max_iter)
self.computed = True
@@ -752,7 +753,7 @@ class OTDA_sinkhorn(OTDA):
"""
- def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
+ def fit(self, xs, xt, reg=1, ws=None, wt=None, **kwargs):
"""Fit regularized domain adaptation between samples is xs and xt
(with optional weights)"""
self.xs = xs
@@ -767,7 +768,7 @@ class OTDA_sinkhorn(OTDA):
self.wt = wt
self.M = dist(xs, xt, metric=self.metric)
- self.M = cost_normalization(self.M, norm)
+ self.M = cost_normalization(self.M, self.norm)
self.G = sinkhorn(ws, wt, self.M, reg, **kwargs)
self.computed = True
@@ -779,8 +780,7 @@ class OTDA_lpl1(OTDA):
"""Class for domain adaptation with optimal transport with entropic and
group regularization"""
- def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
- **kwargs):
+ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs):
"""Fit regularized domain adaptation between samples is xs and xt
(with optional weights), See ot.da.sinkhorn_lpl1_mm for fit
parameters"""
@@ -796,7 +796,7 @@ class OTDA_lpl1(OTDA):
self.wt = wt
self.M = dist(xs, xt, metric=self.metric)
- self.M = cost_normalization(self.M, norm)
+ self.M = cost_normalization(self.M, self.norm)
self.G = sinkhorn_lpl1_mm(ws, ys, wt, self.M, reg, eta, **kwargs)
self.computed = True
@@ -808,8 +808,7 @@ class OTDA_l1l2(OTDA):
"""Class for domain adaptation with optimal transport with entropic
and group lasso regularization"""
- def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
- **kwargs):
+ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs):
"""Fit regularized domain adaptation between samples is xs and xt
(with optional weights), See ot.da.sinkhorn_lpl1_gl for fit
parameters"""
@@ -825,7 +824,7 @@ class OTDA_l1l2(OTDA):
self.wt = wt
self.M = dist(xs, xt, metric=self.metric)
- self.M = cost_normalization(self.M, norm)
+ self.M = cost_normalization(self.M, self.norm)
self.G = sinkhorn_l1l2_gl(ws, ys, wt, self.M, reg, eta, **kwargs)
self.computed = True