diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2018-02-16 13:53:34 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2018-02-16 13:53:34 +0100 |
commit | f089a3cbc27c30ba9416ea1659c2fdbac1857146 (patch) | |
tree | 04905ed08da51190f15faa7aff67acb308ee583f /ot/gpu/da.py | |
parent | 4a585de94109102c89bcd7ad43e35772e1027cd2 (diff) |
better pep8 but not solved
Diffstat (limited to 'ot/gpu/da.py')
-rw-r--r-- | ot/gpu/da.py | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/ot/gpu/da.py b/ot/gpu/da.py index 05c580f..71a485a 100644 --- a/ot/gpu/da.py +++ b/ot/gpu/da.py @@ -188,6 +188,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M_GPU, reg, eta=0.1, numItermax=10, class OTDA_GPU(OTDA): + def normalizeM(self, norm): if norm == "median": self.M_GPU.divide(float(np.median(self.M_GPU.asarray()))) @@ -204,6 +205,7 @@ class OTDA_GPU(OTDA): class OTDA_sinkhorn(OTDA_GPU): + def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs): cudamat.init() xs = np.asarray(xs, dtype=np.float64) @@ -228,6 +230,7 @@ class OTDA_sinkhorn(OTDA_GPU): class OTDA_lpl1(OTDA_GPU): + def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None, **kwargs): cudamat.init() |